@@ -113,7 +113,6 @@ def loop(self) -> None:
             ) as pbar:
                 for step in pbar:
                     i = 0
-                    allow_sync_model = False
                     for _ in range(self.num_recv_per_update):
                         # receive data from producers
                         for r in range(self.num_producers):
@@ -139,7 +138,6 @@ def loop(self) -> None:
                             else:
                                 self.buffer = self.buffer[self.dp_size * self.minibatch_size :]
                             if loss is not None:
-                                allow_sync_model = True
                                 pbar.set_postfix({"loss": loss})
                             i += 1
                     if self.lr_scheduler is not None:
@@ -153,31 +151,29 @@ def loop(self) -> None:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                     if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                        if allow_sync_model:
-                            if self.pp_size > 1:
-                                print(
-                                    f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        if self.pp_size > 1:
+                            print(
+                                f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                            )
+                        else:
+                            print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                        torch.cuda.empty_cache()
+                        state_dict = self.state_dict()
+                        if self.pp_size > 1:
+                            if self.tp_rank == 0 and self.dp_rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict,
+                                    src=self.num_producers,
+                                    device=self.device,
+                                    group_name=f"sync_model_{self.pp_rank}",
                                 )
-                            else:
-                                print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
-                            torch.cuda.empty_cache()
-                            state_dict = self.state_dict()
-                            if self.pp_size > 1:
-                                if self.tp_rank == 0 and self.dp_rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict,
-                                        src=self.num_producers,
-                                        device=self.device,
-                                        group_name=f"sync_model_{self.pp_rank}",
-                                    )
-                            else:
-                                if self.rank == 0:
-                                    ray_broadcast_tensor_dict(
-                                        state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                                    )
-                            del state_dict
-                            torch.cuda.empty_cache()
-                            allow_sync_model = False
+                        else:
+                            if self.rank == 0:
+                                ray_broadcast_tensor_dict(
+                                    state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                                )
+                        del state_dict
+                        torch.cuda.empty_cache()
 
 
 @ray.remote