File tree Expand file tree Collapse file tree 2 files changed +14
-8
lines changed
Expand file tree Collapse file tree 2 files changed +14
-8
lines changed Original file line number Diff line number Diff line change @@ -67,19 +67,25 @@ def is_clear(self):
6767 return len (self .reqs ) == 0
6868
6969 def merge (self , mini_batch : "Batch" ):
70- for _req in mini_batch .reqs :
71- self .reqs .append (_req )
72- self .id_to_reqs = {req .request_id : req for req in self .reqs }
73- return
74-
75- def dp_merge (self , mini_batch : "Batch" ):
7670 if mini_batch is None :
7771 return
78-
72+
7973 for _req in mini_batch .reqs :
8074 self .reqs .append (_req )
8175 self .id_to_reqs = {req .request_id : req for req in self .reqs }
8276 return
77+
78+ @staticmethod
79+ def merge_two_batch (batch1 : "Batch" , batch2 : "Batch" ):
80+ if batch1 is None and batch2 is None :
81+ return None
82+
83+ not_none_batch = batch1 if batch1 is not None else batch2
84+
85+ merge_batch = Batch (- 1 , [], not_none_batch .dp_size_in_node )
86+ merge_batch .merge (batch1 )
87+ merge_batch .merge (batch2 )
88+ return
8389
8490 def __repr__ (self ):
8591 return f"batch_id={ self .batch_id } , " f"reqs={ self .reqs } , "
Original file line number Diff line number Diff line change @@ -43,7 +43,7 @@ def _merge_batch(self, dp_batches: List[Batch]):
4343 merged_batch : Batch = None
4444 for iter_batch in dp_batches :
4545 if merged_batch is not None :
46- merged_batch .dp_merge (iter_batch )
46+ merged_batch .merge (iter_batch )
4747 else :
4848 merged_batch = iter_batch
4949 return merged_batch
You can’t perform that action at this time.
0 commit comments