@@ -94,9 +94,6 @@ def setup(self) -> None:
94
94
if self .rank == 0 :
95
95
cc .init_collective_group (self .num_producers + 1 , self .num_producers , group_name = "sync_model" )
96
96
97
- for i in range (self .num_producers ):
98
- cc .init_collective_group (self .world_size + 1 , self .rank + 1 , group_name = f"sync_eval_statistics_{ i } " )
99
-
100
97
self .buffer = []
101
98
self .recv_cnt = 0
102
99
@@ -116,11 +113,14 @@ def loop(self) -> None:
116
113
i = 0
117
114
if self .eval_interval > 0 and step % self .eval_interval == 0 :
118
115
eval_statistics = None
116
+ eval_global_step = None
119
117
for r in range (self .num_producers ):
120
118
print (f"[T{ dist .get_rank ()} ] Recv eval result episode { episode } step { step } from { r } " )
121
119
local_eval_result = ray_broadcast_tensor_dict (
122
- None , src = 0 , device = self .device , group_name = f"sync_eval_statistics_ { r } "
120
+ None , src = 0 , device = self .device , group_name = f"sync_data_ { r } "
123
121
)
122
+ assert "consumer_global_step" in local_eval_result
123
+ eval_global_step = local_eval_result .pop ("consumer_global_step" ).item ()
124
124
if eval_statistics is None :
125
125
eval_statistics = local_eval_result
126
126
else :
@@ -129,8 +129,8 @@ def loop(self) -> None:
129
129
}
130
130
eval_statistics = {k : (v [0 ] / v [1 ]).item () for k , v in eval_statistics .items ()}
131
131
if dist .get_rank () == 0 :
132
- if hasattr (self , "wandb_run" ) and hasattr ( self , "global_step" ) :
133
- self .wandb_run .log (eval_statistics , step = self . global_step )
132
+ if hasattr (self , "wandb_run" ):
133
+ self .wandb_run .log (eval_statistics , step = eval_global_step )
134
134
print (f"Eval statistics: { eval_statistics } " )
135
135
for _ in range (self .num_recv_per_update ):
136
136
# receive data from producers
0 commit comments