@@ -29,6 +29,7 @@ def __init__(
         tokenizer_config: Optional[Dict[str, Any]] = None,
         microbatch_size: int = 1,
         backend: str = "transformers",
+        consumer_plugin_config: Dict[str, Any] = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
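Review note: `consumer_plugin_config` is annotated `Dict[str, Any]` but defaults to `None`, and the next hunk indexes it unconditionally (`consumer_plugin_config["pp_size"]`), which would raise `TypeError` whenever the default is used. A defensive sketch of the assignment, hypothetical and not part of this diff, that tolerates the `None` default:

```python
# Hypothetical guard (not in this diff): fall back to a single
# pipeline stage when no consumer plugin config is supplied.
consumer_plugin_config = consumer_plugin_config or {}
self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
```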
@@ -78,9 +79,15 @@ def __init__(
         else:
             raise ValueError(f"Unexpected backend {backend}")

+        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
+
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
-        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+        if self.consumer_pp_size > 1:
+            for i in range(self.consumer_pp_size):
+                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+        else:
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
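For reference, the group topology this `setup()` now creates: instead of a single `sync_model` group, the producer joins one broadcast group per consumer pipeline stage, each with world size `num_producers + 1` and the consumer occupying the last rank (rank `num_producers`, which `loop()` below uses as the broadcast source). A standalone bookkeeping sketch of the naming and ranks, with assumed example values that are not from this diff:

```python
# Pure bookkeeping, no Ray calls; num_producers and consumer_pp_size
# are assumed example values (the latter would come from
# consumer_plugin_config["pp_size"]).
num_producers = 4
consumer_pp_size = 2

groups = (
    [f"sync_model_{i}" for i in range(consumer_pp_size)]
    if consumer_pp_size > 1
    else ["sync_model"]
)
for name in groups:
    # ranks 0..num_producers-1 are producers; rank num_producers is the
    # consumer side that broadcasts the weights
    print(f"{name}: world_size={num_producers + 1}, src_rank={num_producers}")
```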
@@ -125,15 +132,25 @@ def loop(self) -> None:
                     ):
                         self.model.llm.sleep()  # evict KV cache to avoid OOM
                     # don't sync model for last iteration
-                    print(
-                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
-                    )
                     torch.cuda.empty_cache()

-                    state_dict = ray_broadcast_tensor_dict(
-                        None, self.num_producers, device=self.device, group_name="sync_model"
-                    )
-                    self.load_state_dict(state_dict)
+                    if self.consumer_pp_size > 1:
+                        for pp_idx in range(self.consumer_pp_size):
+                            print(
+                                f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                            )
+                            state_dict = ray_broadcast_tensor_dict(
+                                None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                            )
+                            self.load_state_dict(state_dict)
+                    else:
+                        print(
+                            f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                        )
+                        state_dict = ray_broadcast_tensor_dict(
+                            None, self.num_producers, device=self.device, group_name="sync_model"
+                        )
+                        self.load_state_dict(state_dict)
                     del state_dict
                     torch.cuda.empty_cache()
                     if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
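The matching consumer-side change is not shown in this file. A hypothetical counterpart, reusing the same group names and the same positional calling convention as the receive above, would have each pipeline stage broadcast only the shard it owns, which is why the producer can call `load_state_dict` once per stage to assemble the full model. The attribute names `pp_size` and `pp_rank` and the surrounding trainer code are assumptions:

```python
# Hypothetical consumer-side send, mirroring the per-stage group names;
# each stage ships only the parameters it holds.
if self.pp_size > 1:
    ray_broadcast_tensor_dict(
        state_dict, self.num_producers, device=self.device, group_name=f"sync_model_{self.pp_rank}"
    )
else:
    ray_broadcast_tensor_dict(state_dict, self.num_producers, device=self.device, group_name="sync_model")
```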
@@ -170,6 +187,7 @@ def __init__(
         microbatch_size=1,
         backend="transformers",
         num_generations: int = 8,
+        consumer_plugin_config=None,
     ):
         super().__init__(
             producer_idx,
@@ -184,6 +202,7 @@ def __init__(
             tokenizer_config,
             microbatch_size,
             backend,
+            consumer_plugin_config,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
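The last two hunks only thread the new argument through `SimpleProducer` to the base class. At launch time, the same plugin dict that configures the consumer's pipeline parallelism would be handed to every producer so both sides agree on the group layout; an illustrative sketch, with values and launcher code that are not part of this diff:

```python
# Illustrative wiring only; the launcher is not shown in this diff.
plugin_config = {"pp_size": 2}  # must match the consumer's pipeline-parallel degree
producer_kwargs = dict(
    microbatch_size=1,
    backend="transformers",
    num_generations=8,
    consumer_plugin_config=plugin_config,
)
```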