Skip to content

Commit 77c9a0e

Browse files
committed
apply review suggestions
1 parent a60ade4 commit 77c9a0e

File tree

3 files changed

+30
-12
lines changed

3 files changed

+30
-12
lines changed

verl/single_controller/base/decorator.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,7 @@
1818

1919
from verl.protocol import DataProtoFuture, _padding_size_key
2020
from verl.utils.py_functional import DynamicEnum
21-
22-
# TODO: Use a hacky workaround for ImportError since
23-
# transfer_queue isn't a default verl dependency.
24-
try:
25-
from transfer_queue import BatchMeta
26-
except ImportError:
27-
28-
class BatchMeta:
29-
pass
30-
21+
from verl.utils.transferqueue_utils import BatchMeta
3122

3223
# here we add a magic number of avoid user-defined function already have this attribute
3324
MAGIC_ATTR = "attrs_3141562937"

verl/trainer/main_ppo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def run_ppo(config, task_runner_class=None) -> None:
6767
ray.init(**OmegaConf.to_container(ray_init_kwargs))
6868

6969
if task_runner_class is None:
70-
task_runner_class = TaskRunner
70+
task_runner_class = ray.remote(TaskRunner).options(
71+
num_cpus=1
72+
) # please make sure main_task is not scheduled on head
7173

7274
# Create a remote instance of the TaskRunner class, and
7375
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
@@ -95,7 +97,6 @@ def run_ppo(config, task_runner_class=None) -> None:
9597
ray.timeline(filename=timeline_json_file)
9698

9799

98-
@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
99100
class TaskRunner:
100101
"""Ray remote class for executing distributed PPO training tasks.
101102

verl/utils/transferqueue_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
except ImportError:
3232
HAS_TQ = False
3333

34+
# TODO: Use a hacky workaround for ImportError since
35+
# transfer_queue isn't a default verl dependency.
36+
class BatchMeta:
37+
pass
38+
39+
3440
from verl.protocol import DataProto
3541

3642
_TRANSFER_QUEUE_CLIENT = None
@@ -135,6 +141,26 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") ->
135141

136142

137143
def tqbridge(put_data: bool = True):
144+
""" "Creates a decorator for bridging BatchMeta and DataProto.
145+
146+
This decorator automatically handles conversions between `BatchMeta` and
147+
`DataProto` in function parameters, and decides whether to sync function
148+
output back to `BatchMeta` based on configuration(`put_data`). It supports
149+
both synchronous and asynchronous functions (async def), and can control
150+
whether to enable enhanced logic via the global `HAS_TQ` variable (when disabled,
151+
simply calls the original function as-is).
152+
153+
Args:
154+
put_data: Whether put the DataProto into Storage after func return.
155+
If True, after function execution, the output result will be
156+
updated to `BatchMeta` and `BatchMeta` will be returned;
157+
If False, the function output result will be returned directly.
158+
Defaults to True.
159+
160+
Returns:
161+
A decorator function used to decorate target functions (synchronous or asynchronous).
162+
"""
163+
138164
def decorator(func):
139165
@wraps(func)
140166
def inner(*args, **kwargs):

0 commit comments

Comments
 (0)