Skip to content

Commit ce78ce4

Browse files
committed
optimize logic
Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
1 parent 9989e83 commit ce78ce4

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

verl/utils/transferqueue_utils.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,15 @@ def _compute_need_collect(dispatch_mode: dict, args: list) -> bool:
159159
return args[0]._Worker__collect_dp_rank[collect_mesh_name]
160160

161161

162+
def _postprocess_common(output, put_data, need_collect):
163+
if put_data and not need_collect:
164+
return BatchMeta.empty()
165+
elif not put_data and not need_collect and isinstance(output, DataProto):
166+
return DataProto()
167+
else:
168+
return output
169+
170+
162171
def tqbridge(dispatch_mode=None, put_data: bool = True):
163172
"""Creates a decorator for bridging BatchMeta and DataProto.
164173
@@ -199,13 +208,11 @@ def inner(*args, **kwargs):
199208
kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
200209
output = func(*args, **kwargs)
201210
need_collect = _compute_need_collect(dispatch_mode, args)
211+
updated_batch_meta = None
202212
if put_data and need_collect:
203213
updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__)
204214
return updated_batch_meta
205-
elif not need_collect:
206-
return BatchMeta.empty()
207-
else:
208-
return output
215+
return _postprocess_common(output, put_data, need_collect, updated_batch_meta)
209216

210217
@wraps(func)
211218
async def async_inner(*args, **kwargs):
@@ -224,21 +231,27 @@ async def async_inner(*args, **kwargs):
224231
}
225232
output = await func(*args, **kwargs)
226233
need_collect = _compute_need_collect(dispatch_mode, args)
234+
updated_batchmeta = None
227235
if put_data and need_collect:
228236
updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__)
229237
return updated_batchmeta
230-
elif not need_collect:
231-
return BatchMeta.empty()
232-
else:
233-
return output
238+
return _postprocess_common(output, put_data, need_collect)
234239

235240
@wraps(func)
236241
def dummy_inner(*args, **kwargs):
237-
return func(*args, **kwargs)
242+
output = func(*args, **kwargs)
243+
need_collect = _compute_need_collect(dispatch_mode, args)
244+
if not need_collect:
245+
return DataProto()
246+
return output
238247

239248
@wraps(func)
240249
async def dummy_async_inner(*args, **kwargs):
241-
return await func(*args, **kwargs)
250+
output = await func(*args, **kwargs)
251+
need_collect = _compute_need_collect(dispatch_mode, args)
252+
if not need_collect:
253+
return DataProto()
254+
return output
242255

243256
wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
244257
wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner

0 commit comments

Comments
 (0)