@@ -129,7 +129,7 @@ def __init__(
129
129
else :
130
130
raise ValueError (f"Unexpected backend { backend } " )
131
131
132
- self .consumer_pp_size = consumer_plugin_config [ "pp_size" ] # consumer pp size
132
+ self .consumer_pp_size = consumer_plugin_config . get ( "pp_size" , 1 ) # consumer pp size
133
133
134
134
def setup (self ) -> None :
135
135
cc .init_collective_group (1 + self .num_consumer_procs , 0 , group_name = f"sync_data_{ self .producer_idx } " )
@@ -250,14 +250,11 @@ def loop(self) -> None:
250
250
# linear annealing for 1 episode, temperature from initial to 0.9
251
251
if episode <= 0 :
252
252
ratio = 1 - (len (self .train_dataloader ) - i ) / len (self .train_dataloader )
253
- if isinstance (self .model .generate_config .temperature , dict ):
254
- self .model .generate_config ["temperature" ] = (1 - ratio ) * self .generate_config [
255
- "temperature"
256
- ] + ratio * 0.9
257
- else :
258
- self .model .generate_config .temperature = (1 - ratio ) * self .generate_config [
259
- "temperature"
260
- ] + ratio * 0.9
253
+ self .model .generate_config ["temperature" ] = (1 - ratio ) * self .generate_config [
254
+ "temperature"
255
+ ] + ratio * 0.9
256
+ if hasattr (self .model , "sample_params" ):
257
+ self .model .sample_params .temperature = self .model .generate_config ["temperature" ]
261
258
262
259
263
260
@ray .remote
@@ -310,8 +307,8 @@ def __init__(
310
307
@torch .no_grad ()
311
308
def rollout (self , input_ids , attention_mask , ** kwargs ):
312
309
rollouts = self .model .generate (input_ids , attention_mask , ** kwargs )
313
- # if self.producer_idx == 1:
314
- # print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
310
+ if self .producer_idx == 1 :
311
+ print ("Rollout example:\n " , self .tokenizer .decode (rollouts ["input_ids" ][0 ][0 ], skip_special_tokens = True ))
315
312
316
313
return rollouts
317
314
0 commit comments