@@ -159,6 +159,15 @@ def _compute_need_collect(dispatch_mode: dict, args: list) -> bool:
159159 return args [0 ]._Worker__collect_dp_rank [collect_mesh_name ]
160160
161161
162+ def _postprocess_common (output , put_data , need_collect ):
163+ if put_data and not need_collect :
164+ return BatchMeta .empty ()
165+ elif not put_data and not need_collect and isinstance (output , DataProto ):
166+ return DataProto ()
167+ else :
168+ return output
169+
170+
162171def tqbridge (dispatch_mode = None , put_data : bool = True ):
163172 """Creates a decorator for bridging BatchMeta and DataProto.
164173
@@ -199,13 +208,11 @@ def inner(*args, **kwargs):
199208 kwargs = {k : _batchmeta_to_dataproto (v ) if isinstance (v , BatchMeta ) else v for k , v in kwargs .items ()}
200209 output = func (* args , ** kwargs )
201210 need_collect = _compute_need_collect (dispatch_mode , args )
211+ updated_batch_meta = None
202212 if put_data and need_collect :
203213 updated_batch_meta = _update_batchmeta_with_output (output , batchmeta , func .__name__ )
204214 return updated_batch_meta
205- elif not need_collect :
206- return BatchMeta .empty ()
207- else :
208- return output
215+ return _postprocess_common (output , put_data , need_collect , updated_batch_meta )
209216
210217 @wraps (func )
211218 async def async_inner (* args , ** kwargs ):
@@ -224,21 +231,27 @@ async def async_inner(*args, **kwargs):
224231 }
225232 output = await func (* args , ** kwargs )
226233 need_collect = _compute_need_collect (dispatch_mode , args )
234+ updated_batchmeta = None
227235 if put_data and need_collect :
228236 updated_batchmeta = await _async_update_batchmeta_with_output (output , batchmeta , func .__name__ )
229237 return updated_batchmeta
230- elif not need_collect :
231- return BatchMeta .empty ()
232- else :
233- return output
238+ return _postprocess_common (output , put_data , need_collect )
234239
235240 @wraps (func )
236241 def dummy_inner (* args , ** kwargs ):
237- return func (* args , ** kwargs )
242+ output = func (* args , ** kwargs )
243+ need_collect = _compute_need_collect (dispatch_mode , args )
244+ if not need_collect :
245+ return DataProto ()
246+ return output
238247
239248 @wraps (func )
240249 async def dummy_async_inner (* args , ** kwargs ):
241- return await func (* args , ** kwargs )
250+ output = await func (* args , ** kwargs )
251+ need_collect = _compute_need_collect (dispatch_mode , args )
252+ if not need_collect :
253+ return DataProto ()
254+ return output
242255
243256 wrapper_inner = inner if is_transferqueue_enabled else dummy_inner
244257 wrapper_async_inner = async_inner if is_transferqueue_enabled else dummy_async_inner
0 commit comments