We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4702d57 commit 45ac6c6Copy full SHA for 45ac6c6
applications/ColossalChat/coati/distributed/producer.py
@@ -154,7 +154,11 @@ def __init__(
154
155
@torch.no_grad()
156
def rollout(self, input_ids, attention_mask, **kwargs):
157
- return self.model.generate(input_ids, attention_mask, **kwargs)
+ rollouts = self.model.generate(input_ids, attention_mask, **kwargs)
158
+ if self.producer_idx == 1:
159
+ print("Rollout example:\n", self.tokenizer.decode(rollouts["input_ids"][0][0], skip_special_tokens=True))
160
+
161
+ return rollouts
162
163
def load_state_dict(self, state_dict):
164
self.model.load_state_dict(state_dict)
0 commit comments