@@ -114,7 +114,6 @@ def loop(self) -> None:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -140,7 +139,6 @@ def loop(self) -> None:
                             else:
                                 self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
-                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -154,31 +152,29 @@ def loop(self) -> None:
                             print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")

                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                            allow_sync_model = False
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()


 @ray.remote