@@ -44,11 +44,8 @@ def _merge_batch(self, dp_batches: List[Batch]):
4444
4545 def append (self , req : Req ):
4646 suggested_dp_index = req .sample_params .suggested_dp_index
47- if suggested_dp_index is None or suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
48- if suggested_dp_index is not None and (
49- suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0
50- ):
51- logger .error (f"input req { req .request_id } dp index { suggested_dp_index } has error" )
47+ if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
48+ logger .error (f"input req { req .request_id } dp index { suggested_dp_index } has error" )
5249 suggested_dp_index = random .randint (0 , self .dp_size_in_node - 1 )
5350 req .sample_params .suggested_dp_index = suggested_dp_index
5451 self .inner_queues [suggested_dp_index ].append (req )
@@ -61,11 +58,8 @@ def extend(self, req_group: List[Req]):
6158 index = random .randint (0 , self .dp_size_in_node - 1 )
6259 for req in req_group :
6360 suggested_dp_index = req .sample_params .suggested_dp_index
64- if suggested_dp_index is None or suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
65- if suggested_dp_index is not None and (
66- suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0
67- ):
68- logger .error (f"input req { req .request_id } dp index { suggested_dp_index } has error" )
61+ if suggested_dp_index >= self .dp_size_in_node or suggested_dp_index < 0 :
62+ logger .error (f"input req { req .request_id } dp index { suggested_dp_index } has error" )
6963 req .sample_params .suggested_dp_index = index
7064 self .inner_queues [index ].append (req )
7165 else :
0 commit comments