diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f25738f546..f3ea387118 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -377,6 +377,7 @@ explorer: dynamic_timeout: enable: false ratio: 3.0 + runner_state_report_interval: 0 ``` - `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique. @@ -397,6 +398,7 @@ explorer: - `dynamic_timeout`: [Experimental] Configurations for dynamic timeout mechanism, which adjusts the timeout for each task based on the average time taken for successful tasks. - `enable`: Whether to enable dynamic timeout. Default is `false`. - `ratio`: The timeout for each task is dynamically set to `average_time_per_success_task * ratio`. Default is `3.0`. +- `runner_state_report_interval`: Workflow runner report interval (in seconds). If set to a value greater than `0`, the workflow runner will periodically report its status to the main explorer process and print it in the command line for monitoring. Default is `0`, meaning this feature is disabled. If you want to use this feature, it is recommended to set it to `10` seconds or longer to minimize performance impact. --- diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index c6dc344007..9ebab50322 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -368,26 +368,34 @@ explorer: tensor_parallel_size: 1 eval_interval: 100 eval_on_startup: True + over_rollout: + ratio: 0.0 + wait_after_min: 30.0 + dynamic_timeout: + enable: false + ratio: 3.0 + runner_state_report_interval: 0 ``` - `name`: explorer 的名称。该名称将用作 Ray actor 的名称,因此必须唯一。 -- `runner_per_model`: 每个 rollout 模型的并行工作流执行器数量。 -- `max_timeout`: 工作流完成的最大时间(秒)。 -- `max_retry_times`: 工作流的最大重试次数。 -- `env_vars`: 为每个工作流执行器设置的环境变量。 +- `runner_per_model`: 每个推理引擎实例所服务的 WorkflowRunner 数量。 +- `max_timeout`: 等待 Workflow 完成的最大时间(秒)。 +- `max_retry_times`: Workflow 失败或超时情况下的最大重试次数。 +- `env_vars`: 为每个 WorkflowRunner 设置的环境变量。 - `rollout_model.engine_type`: 推理引擎类型。支持 `vllm_async` 和 `vllm`,二者的含义相同,都使用了异步引擎。后续版本会只保留 `vllm`。 -- `rollout_model.engine_num`: 推理引擎数量。 -- `rollout_model.tensor_parallel_size`: 张量并行度。 -- `rollout_model.enable_history`: 是否启用模型调用历史记录。若设为 `True`,模型包装器会自动记录模型调用返回的 experience。请定期通过 `extract_experience_from_history` 提取历史,以避免内存溢出。默认为 `False`。 -- `auxiliary_models`: 用于自定义工作流的额外模型。 -- `eval_interval`: 模型评估的间隔(以步为单位)。 -- `eval_on_startup`: 是否在启动时评估模型。更准确地说,是在第 0 步使用原始模型评估,因此重启时不会触发。 +- `rollout_model.engine_num`: 推理引擎实例的数量。 +- `rollout_model.tensor_parallel_size`: 每个实例的张量并行度。 +- `rollout_model.enable_history`: 是否启用模型调用历史记录功能。若设为 `True`,模型会自动记录调用返回的 experience。请定期通过 `extract_experience_from_history` 提取历史,以避免内存溢出。默认为 `False`。 +- `auxiliary_models`: 用于自定义工作流的辅助模型。 +- `eval_interval`: 模型评估的间隔(以 step 为单位)。 +- `eval_on_startup`: 是否在启动时评估模型。更准确地说,是在第 0 步使用原始模型评估,因此中断训练后重启时不会触发该行为。 - `over_rollout`: [实验性] 超量 rollout 机制的配置,允许 explorer 在每个步骤中使用少于完整批次大小的任务继续进行。这在某些任务显著耗时较长的场景中能有效地提高吞吐量。仅当使用动态同步(`synchronizer.sync_style` 不是 `fixed`)时适用。 - `ratio`: explorer 在每个步骤中仅等待 `(1 - ratio) * batch_size` 的任务。默认为 `0.0`,表示等待所有任务。 - `wait_after_min`: 达到最小任务阈值后,等待此秒数后再继续。 - `dynamic_timeout`: [实验性] 动态超时机制的配置,根据成功任务的平均耗时调整每个任务的超时时间。 - `enable`: 是否启用动态超时。默认为 `false`。 - `ratio`: 每个任务的超时时间动态设置为 `average_time_per_success_task * ratio`。默认为 `3.0`。 +- 
`runner_state_report_interval`: WorkflowRunner 报告自身状态的时间间隔(秒)。若设为大于 0 的值,工作流执行器会定期将其状态报告给 explorer 主进程并打印在命令行中,以便监控其运行状态。默认为 `0`,表示不启用此功能。推荐如需使用此功能,将其设置为 `10` 秒或更长时间以减少对性能的影响。 --- diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index f4a69efa97..ae4222b538 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -1,6 +1,7 @@ import asyncio import time import unittest +from collections import defaultdict from typing import List, Optional import ray @@ -11,7 +12,7 @@ from trinity.common.config import ExperienceBufferConfig from trinity.common.constants import StorageType, SyncStyle from trinity.common.experience import EID, Experience -from trinity.common.models.model import InferenceModel +from trinity.common.models.model import InferenceModel, ModelWrapper from trinity.common.workflows import Task from trinity.common.workflows.workflow import WORKFLOWS, Workflow from trinity.explorer.scheduler import Scheduler @@ -134,6 +135,41 @@ def run(self): raise RuntimeError("This method should not be called") +@WORKFLOWS.register_module("dummy_workflow_with_state") +class DummyWorkflowWithState(Workflow): + can_repeat: bool = True + is_async: bool = True + + def __init__(self, *, task, model: ModelWrapper, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.step_num = task.workflow_args.get("step_num", 1) + + def set_repeat_times(self, repeat_times, run_id_base): + self.repeat_times = repeat_times + self.run_id_base = run_id_base + + async def run_async(self) -> List[Experience]: + exps = [] + for i in range(self.repeat_times): + run_level_metrics = {"run_metrics": float(i + self.run_id_base)} + run_level_exps = [] + for step in range(self.step_num): + run_level_exps.append( + Experience( + eid=EID(run=i + self.run_id_base, step=step), + tokens=torch.zeros(5), + prompt_length=2, + prompt_text="success", + ) + ) + run_level_exps[-1].metrics = run_level_metrics + self.logger.info(f"Setting workflow state to repeat_cnt={i}") + await self.model.set_workflow_state({"repeat_cnt": i}) + await asyncio.sleep(1) + exps.extend(run_level_exps) + return exps + + @ray.remote class DummyModel(InferenceModel): def sync_model(self, model_version, update_weight_args_list): @@ -779,3 +815,58 @@ def tearDown(self): ray.shutdown() except Exception: pass + + +class TestRunnerStateCollection(unittest.IsolatedAsyncioTestCase): + async def test_runner_state_collection(self): + ray.init(ignore_reinit_error=True) + config = get_template_config() + config.explorer.runner_per_model = 2 + config.explorer.runner_state_report_interval = 0.5 + config.explorer.max_repeat_times_per_runner = 2 + config.check_and_update() + scheduler = Scheduler(config, [DummyModel.remote(), DummyModel.remote()]) + # 4 runner in side the scheduler + await scheduler.start() + + tasks = [ + Task( + workflow=DummyWorkflowWithState, # type: ignore[type-abstract] + workflow_args={"step_num": 2}, + repeat_times=4, + raw_task={}, + ) + for _ in range(4) + ] + scheduler.schedule(tasks, batch_id=0) + + async def monitor_routine(): + runner_0_state_history = defaultdict(set) + await asyncio.sleep(0.5) # wait for first report + for _ in range(16): + await asyncio.sleep(0.3) + states = scheduler.get_all_state() + self.assertEqual(len(states), 4) + for state in states.values(): + self.assertIn("workflow_id", state) + self.assertIn("model_version", state) + self.assertIn("begin_time", state) + self.assertIn("terminate_time", state) + 
self.assertIn("repeat_cnt", state) + ids = scheduler.get_key_state("workflow_id") + self.assertEqual(len(ids), 4) + self.assertEqual(len(set(ids.values())), 4) + runner_0_state = scheduler.get_runner_state(0) + for key, value in runner_0_state.items(): + runner_0_state_history[key].add(value) + self.assertEqual(len(runner_0_state_history["repeat_cnt"]), 2) # max_repeat_times is 2 + self.assertEqual(len(runner_0_state_history["model_version"]), 1) + self.assertEqual( + len(runner_0_state_history["workflow_id"]), 2 + ) # split into 2 sub tasks + self.assertEqual(len(runner_0_state_history["begin_time"]), 2) + + await asyncio.gather( + monitor_routine(), + scheduler.get_results(batch_id=0), + ) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 81e1afa916..e0a86da6f1 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -2,6 +2,7 @@ """Test for the workflow module""" import asyncio import unittest +from collections import defaultdict from dataclasses import dataclass, field from typing import Dict, Optional from unittest import mock @@ -508,6 +509,48 @@ def tearDown(self): ray.shutdown(_exiting_interpreter=True) +class StateRecordingWorkflow(Workflow): + is_async: bool = True + + def __init__(self, *, task, model: ModelWrapper, auxiliary_models): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.wait_time = task.workflow_args.get("wait_time", 1) + + async def run_async(self): + for i in range(self.wait_time): + await self.model.set_workflow_state({"step": i}) + await asyncio.sleep(1) + return [Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, reward=1.0)] + + +class TestWorkflowStateRecording(unittest.IsolatedAsyncioTestCase): + async def test_workflow_state_recording(self): + model = MagicMock() + model_wrapper = ModelWrapper(model, engine_type="vllm") + + task = Task( + workflow=StateRecordingWorkflow, + repeat_times=3, + raw_task={}, + workflow_args={"wait_time": 3}, + ) + workflow = task.to_workflow(model_wrapper) + + async def monitor_routine(): + old_state = {} + count = 0 + for i in range(20): + await asyncio.sleep(0.2) + new_state = await model_wrapper.get_workflow_state() + if new_state.get("step") != old_state.get("step"): + old_state = new_state + count += 1 + self.assertEqual(count, 3) + return count + + await asyncio.gather(*[monitor_routine(), workflow.run_async()]) + + class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase): @unittest.skip("Waiting for agentscope>=0.1.6") async def test_adapter(self): @@ -559,6 +602,9 @@ def get_openai_client(self): def get_openai_async_client(self): return openai.AsyncOpenAI(api_key="EMPTY") + async def clean_workflow_state(self): + return + @property async def model_version_async(self): return 0 @@ -604,3 +650,50 @@ async def test_workflow_runner(self): self.assertTrue(status.ok) self.assertIsInstance(exps, list) self.assertEqual(len(exps), 2) + + async def test_workflow_runner_get_state(self): + config = get_template_config() + + async def mock_get_api_server_url_remote(): + return None + + async def mock_get_model_version_remote(): + return 1 + + model = MagicMock() + model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote) + model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote) + + runner = WorkflowRunner( + config, + model=model, + auxiliary_models=[], + runner_id=1, + ) + await runner.prepare() + + task = Task( + workflow=StateRecordingWorkflow, + 
raw_task={}, + workflow_args={"wait_time": 2}, + batch_id=1, + task_id=2, + ) + + async def monitor_routine(): + state_history = defaultdict(set) + count = 0 + for i in range(20): + await asyncio.sleep(0.4) + new_state = await runner.get_runner_state() + for k, v in new_state.items(): + state_history[k].add(v) + self.assertEqual(len(state_history["model_version"]), 1) + self.assertEqual(len(state_history["workflow_id"]), 3) + self.assertEqual(len(state_history["begin_time"]), 3) + self.assertEqual(len(state_history["step"]), 2) + return count + + await asyncio.gather( + *[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)] + ) diff --git a/trinity/common/config.py b/trinity/common/config.py index bd2d6f6907..8dd622e569 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -670,6 +670,8 @@ class ExplorerConfig: # Experimental feature over_rollout: OverRolloutConfig = field(default_factory=OverRolloutConfig) dynamic_timeout: DynamicTimeoutConfig = field(default_factory=DynamicTimeoutConfig) + # report runner state every `runner_state_report_interval` seconds, 0 to disable + runner_state_report_interval: int = 0 @dataclass diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index a2ecdd90a4..bd0f36c99c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -4,7 +4,7 @@ import socket from abc import ABC, abstractmethod from functools import partial -from typing import List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import httpx import numpy as np @@ -103,7 +103,9 @@ def __init__( self.enable_history = enable_history self.history = [] self.status = RunningStatus.RUNNING + self.workflow_state: Dict = {} self.request_count = 0 + self.state_lock = asyncio.Lock() async def prepare(self) -> None: """Prepare the model wrapper.""" @@ -361,6 +363,22 @@ def extract_experience_from_history(self, clear_history: bool = True) -> List[Ex self.history.clear() return exps + # Workflow state management methods + async def set_workflow_state(self, state: Dict) -> None: + """Set the state of workflow using the model.""" + async with self.state_lock: + self.workflow_state.update(state) + + async def clean_workflow_state(self) -> None: + """Clean the state of workflow using the model.""" + async with self.state_lock: + self.workflow_state = {} + + async def get_workflow_state(self) -> Dict: + """Get the state of workflow using the model.""" + async with self.state_lock: + return self.workflow_state.copy() + def convert_api_output_to_experience( output, diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 686b08f080..30ece621d3 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -380,6 +380,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: metric.update(pipeline_metrics) if statuses: metric.update(gather_metrics([status.metrics[0] for status in statuses], "rollout")) + metric["rollout/finished_task_count"] = len(statuses) self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: @@ -392,10 +393,11 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva if eval_step != step: return self.pending_eval_tasks.popleft() - eval_results, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") + statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") + 
metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses) metric.update( gather_metrics( - [status.metrics[0] for status in eval_results], f"{prefix}/{eval_task_name}" + [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}" ) ) if self.eval_start_time is not None: diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index ae2fdfe4a6..d86859dbc1 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -61,6 +61,7 @@ def __init__( self.timeout = config.explorer.max_timeout self.namespace = ray.get_runtime_context().namespace self.runner = self._create_runner() + self.state = {} def _create_runner(self): return ( @@ -79,6 +80,10 @@ def _create_runner(self): async def prepare(self): await self.runner.prepare.remote() + async def update_state(self) -> None: + """Get the runner state.""" + self.state = await self.runner.get_runner_state.remote() + async def run_with_retry( self, task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float ) -> Tuple[Status, List, int, float]: @@ -253,6 +258,24 @@ async def _scheduler_loop(self) -> None: await asyncio.sleep(0.1) self.logger.info("Scheduler loop stopped.") + async def _monitor_runner_state_loop(self) -> None: + interval = self.config.explorer.runner_state_report_interval + if interval <= 0: + self.logger.info("Runner state monitoring loop disabled.") + return + + self.logger.info("Runner state monitoring loop started.") + while self.running: + try: + await asyncio.gather(*[runner.update_state() for runner in self.runners.values()]) + self.print_all_state() + except Exception: + self.logger.error( + f"Error in runner state monitoring loop:\n{traceback.format_exc()}" + ) + await asyncio.sleep(interval) + self.logger.info("Runner state monitoring loop stopped.") + async def _schedule_pending_tasks(self) -> None: if not self.idle_runners: return @@ -333,6 +356,7 @@ async def start(self) -> None: self.scheduler_task = asyncio.create_task(self._scheduler_loop()) ready_refs = [runner.runner.__ray_ready__.remote() for runner in self.runners.values()] await asyncio.gather(*ready_refs) + self.monitor_task = asyncio.create_task(self._monitor_runner_state_loop()) self.logger.info(f"Starting Scheduler with {self.runner_num} runners") async def stop(self) -> None: @@ -354,6 +378,12 @@ async def stop(self) -> None: await self.scheduler_task except asyncio.CancelledError: pass + if self.monitor_task: + self.monitor_task.cancel() + try: + await self.monitor_task + except asyncio.CancelledError: + pass self.logger.info("Scheduler stopped") def schedule(self, tasks: List[Task], batch_id: Union[int, str]) -> None: @@ -525,3 +555,77 @@ async def wait_all( ) raise TimeoutError(error_msg) + + def get_key_state(self, key: str) -> Dict: + """Get the scheduler state. + + Args: + key (`str`): The key of the state to get. + + Returns: + `Dict`: A dictionary of runner ids to their state for the given key. + """ + result = {} + for runner in self.runners.values(): + runner_state = runner.state + if runner_state and key in runner_state: + result[runner.runner_id] = runner_state[key] + return result + + def get_runner_state(self, runner_id: int) -> Dict: + """Get the scheduler state. + + Args: + runner_id (`int`): The id of the runner. + + Returns: + `Dict`: The state of the runner. + """ + runner = self.runners.get(runner_id, None) + if runner: + return runner.state + else: + return {} + + def get_all_state(self) -> Dict: + """Get all runners' state. 
+ + Returns: + `Dict`: The state of all runners. + """ + result = {} + for runner in self.runners.values(): + runner_state = runner.state + if runner_state: + result[runner.runner_id] = runner_state + return result + + def print_all_state(self) -> None: + """Print all runners' state in a clear, aligned table format.""" + all_keys = set() + for runner in self.runners.values(): + runner_state = runner.state + if runner_state: + all_keys.update(runner_state.keys()) + all_keys = sorted(all_keys) + # Prepare header + header = ["runner_id"] + all_keys # type: ignore [operator] + # Prepare rows + rows = [] + for runner in self.runners.values(): + runner_state = runner.state or {} + row = [str(runner.runner_id)] + for key in all_keys: + value = runner_state.get(key, "-") + row.append(str(value)) + rows.append(row) + # Calculate column widths + col_widths = [max(len(str(x)) for x in col) for col in zip(header, *rows)] + # Print header + header_line = " | ".join(str(h).ljust(w) for h, w in zip(header, col_widths)) + self.logger.info(header_line) + self.logger.info("-+-".join("-" * w for w in col_widths)) + # Print each row + for row in rows: + line = " | ".join(str(cell).ljust(w) for cell, w in zip(row, col_widths)) + self.logger.info(line) diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 85af23aa1b..314f8b94e0 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -74,6 +74,12 @@ def __init__( self.auxiliary_model_async_clients = [] self.workflow_instance: Workflow = None self.runner_id = runner_id + self.runner_state = { + "workflow_id": None, + "model_version": None, + "begin_time": 0, + "terminate_time": 0, + } async def prepare(self) -> None: """Prepare the runner.""" @@ -121,23 +127,35 @@ async def _run_task( ) -> Tuple[List[Experience], List[Dict]]: """Init workflow from the task and run it.""" self._create_workflow_instance(task) + if self.workflow_instance.repeatable: self.workflow_instance.set_repeat_times(repeat_times, run_id_base) st = time.time() + await self.model_wrapper.clean_workflow_state() + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_id_base}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st exps = await self._run_workflow(self.workflow_instance) - task_execution_time = time.time() - st + et = time.time() + self.runner_state["terminate_time"] = et # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly run_metrics = [exp.metrics for exp in exps if exp.metrics] for metric in run_metrics: - metric["time/task_execution"] = task_execution_time + metric["time/task_execution"] = et - st else: exps = [] run_metrics = [] for i in range(repeat_times): st = time.time() + await self.model_wrapper.clean_workflow_state() + self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{i}" + self.runner_state["terminate_time"] = None + self.runner_state["begin_time"] = st new_exps = await self._run_workflow(self.workflow_instance) + et = time.time() + self.runner_state["terminate_time"] = et run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/task_execution"] = time.time() - st + run_metric["time/task_execution"] = et - st run_metrics.append(run_metric) for exp in new_exps: exp.eid.run = run_id_base + i @@ -146,6 +164,12 @@ async def _run_task( self._create_workflow_instance(task) return exps, run_metrics + async def get_runner_state(self) -> Dict: + """Get the runner state.""" + 
runner_state = self.runner_state.copy() + runner_state.update(await self.model_wrapper.get_workflow_state()) + return runner_state + async def run_task( self, task: Task, @@ -156,9 +180,10 @@ async def run_task( # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead try: st = time.time() + model_version = await self.model_wrapper.model_version_async + self.runner_state["model_version"] = model_version exps, metrics = await self._run_task(task, repeat_times, run_id_base) assert exps is not None and len(exps) > 0, "An empty experience is generated" - model_version = await self.model_wrapper.model_version_async # set eid for each experience for exp in exps: exp.eid.batch = task.batch_id diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 3e508c9124..dd977206c8 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -464,7 +464,7 @@ def _build_model_optimizer( # noqa: C901 num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + print(f"num_warmup_steps: {num_warmup_steps}") if warmup_style == "constant": actor_lr_scheduler = get_constant_schedule_with_warmup( @@ -1180,7 +1180,7 @@ def _build_critic_model_optimizer(self, config): # noqa: C901 num_warmup_steps = int(num_warmup_steps_ratio * total_steps) if self.rank == 0: - print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}") + print(f"num_warmup_steps: {num_warmup_steps}") from verl.utils.torch_functional import ( get_constant_schedule_with_warmup,
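---

Not part of the patch above: a minimal sketch of how a user-defined workflow might surface its progress through the new `ModelWrapper.set_workflow_state` API, modeled on the `DummyWorkflowWithState` and `StateRecordingWorkflow` test workflows in this diff. The workflow name and the state keys (`phase`, `turn`) are illustrative assumptions, not anything the framework requires.

```python
import asyncio
from typing import List

import torch

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows.workflow import WORKFLOWS, Workflow


@WORKFLOWS.register_module("progress_reporting_workflow")
class ProgressReportingWorkflow(Workflow):
    """Illustrative workflow that publishes per-turn progress to the runner state."""

    is_async: bool = True

    def __init__(self, *, task, model: ModelWrapper, auxiliary_models):
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
        self.num_turns = task.workflow_args.get("num_turns", 3)

    async def run_async(self) -> List[Experience]:
        exps = []
        for turn in range(self.num_turns):
            # Any small dict works; the keys used here are hypothetical.
            # The runner merges this dict into the state it reports back.
            await self.model.set_workflow_state({"phase": "rollout", "turn": turn})
            await asyncio.sleep(1)  # stand-in for a model call or environment step
            exps.append(Experience(tokens=torch.zeros(5), prompt_length=2, reward=1.0))
        return exps
```

The runner clears this per-workflow state before each run (`clean_workflow_state`) and merges it into `WorkflowRunner.get_runner_state()` together with `workflow_id`, `model_version`, `begin_time`, and `terminate_time`, so the reported table shows both framework-level and workflow-level progress.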
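Also not part of the patch: a sketch of enabling the report loop and reading the collected state on the scheduler side, following `TestRunnerStateCollection`. The `config` and `rollout_models` arguments and the surrounding coroutine are assumptions for illustration.

```python
from typing import List

from trinity.common.workflows import Task
from trinity.explorer.scheduler import Scheduler


async def run_with_state_monitoring(config, rollout_models, tasks: List[Task]):
    # Report runner state every 10 seconds; the default of 0 disables the loop.
    config.explorer.runner_state_report_interval = 10
    config.check_and_update()

    scheduler = Scheduler(config, rollout_models)
    await scheduler.start()  # also starts _monitor_runner_state_loop()
    scheduler.schedule(tasks, batch_id=0)

    # States are refreshed once per report interval, so they stay empty until
    # the first report arrives. print_all_state() logs the same data as a table.
    all_states = scheduler.get_all_state()                 # {runner_id: state dict}
    workflow_ids = scheduler.get_key_state("workflow_id")  # {runner_id: value}
    runner_0_state = scheduler.get_runner_state(0)         # one runner's state

    statuses, _ = await scheduler.get_results(batch_id=0)
    await scheduler.stop()
    return all_states, workflow_ids, runner_0_state, statuses
```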