
Commit 8745e8f

test asyncllm producer and other settings

1 parent 2b46ab1 commit 8745e8f

4 files changed: +15 -9 lines changed

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 1 deletion
@@ -181,7 +181,6 @@ def loop(self) -> None:
         for step in pbar:
             torch.cuda.reset_peak_memory_stats()
             i = 0
-
             self.profiler.enter(f"rollout_episode_{episode}_step_{step}")
             for _ in range(self.num_recv_per_update):
                 if self.n_behind > 0:
@@ -325,6 +324,7 @@ def loop(self) -> None:
                 )  # for setting start index when resuming training
                 if self.rank == 0:
                     print(f"Saved model checkpoint at step {step + 1} in folder {self.save_dir}")
+
             if (episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1) and (
                 episode != 0 or step >= self.n_behind
             ):

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 8 additions & 3 deletions
@@ -251,7 +251,12 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
         micro_batch_input_ids_no_padding = [
             micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
         ]
-        sample_params = kwargs.get("sample_params", self.sample_params)
+        sample_params = self.sample_params
+        if len(kwargs) > 0:
+            sample_params = self.generate_config.copy()
+            sample_params.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]})
+            sample_params.update(self.FORCE_GENERATE_CONFIG)
+            sample_params = SamplingParams(**sample_params)
         outputs = self.llm.generate(
             prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
         )
@@ -358,7 +363,7 @@ async def generate(
             input_ids (torch.Tensor): shape [B, S], B=1
             attention_mask (torch.Tensor): shape [B, S]
         """
-        assert input_ids.size(0) == attention_mask.size(0) == 1
+        assert input_ids.size(0) == attention_mask.size(0) == 1, "AsyncVLLMInferenceBackend only supports batch size 1"
         request_id = (
             str(uuid4()) if not "request_id" in kwargs else kwargs.pop("request_id")
         )  # use fixed request_id to reuse kv cache
@@ -368,7 +373,7 @@ async def generate(
         sample_params = self.sample_params
         if len(kwargs) > 0:
             sample_params = self.generate_config.copy()
-            sample_params.update(kwargs)
+            sample_params.update({k: v for k, v in kwargs.items() if k not in ["gt_answer", "test_cases", "labels"]})
             sample_params.update(self.FORCE_GENERATE_CONFIG)
             sample_params = SamplingParams(**sample_params)
         out_tokens = []

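Both the batched and the async vLLM backend now assemble their SamplingParams the same way: start from the backend's generate_config, overlay any per-call generation kwargs while dropping data-only keys (gt_answer, test_cases, labels), and apply FORCE_GENERATE_CONFIG last so forced settings always win. A minimal standalone sketch of that merge order; the helper name and the example values are illustrative, not taken from the codebase:

from vllm import SamplingParams

# Keys that ride along with the batch for reward computation, never meant for the sampler
# (the exact list is the one used in the diff above).
DATA_ONLY_KEYS = {"gt_answer", "test_cases", "labels"}

def build_sampling_params(generate_config: dict, force_config: dict, **kwargs) -> SamplingParams:
    """Merge order: defaults < per-call kwargs < forced settings, skipping data-only keys."""
    params = generate_config.copy()
    params.update({k: v for k, v in kwargs.items() if k not in DATA_ONLY_KEYS})
    params.update(force_config)  # forced keys always override per-call values
    return SamplingParams(**params)

# Example: a per-call temperature override; gt_answer is ignored by the sampler.
sp = build_sampling_params(
    {"temperature": 1.0, "max_tokens": 512},
    {"logprobs": 0},
    temperature=0.7,
    gt_answer=["42"],
)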
applications/ColossalChat/coati/distributed/launch.py

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ def launch_distributed(
         tokenizer_config=tokenizer_config,
         microbatch_size=(
             inference_microbatch_size * num_generations
-            if "async" in inference_backend
+            if "async-agentic" in inference_backend
             else inference_microbatch_size
         ),
         backend=inference_backend,

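The launch change narrows the microbatch scaling: only backends whose name contains "async-agentic" hand the producer inference_microbatch_size * num_generations, presumably because each agentic prompt fans out into num_generations rollouts, while the plain async and synchronous backends keep the raw microbatch size. A small sketch of that selection logic; the backend strings in the checks below are placeholders, not real backend names:

def effective_microbatch_size(inference_backend: str, inference_microbatch_size: int, num_generations: int) -> int:
    # Only agentic async backends fan each prompt out into num_generations rollouts.
    if "async-agentic" in inference_backend:
        return inference_microbatch_size * num_generations
    return inference_microbatch_size

assert effective_microbatch_size("async-agentic-vllm", 4, 8) == 32  # hypothetical backend name
assert effective_microbatch_size("async-vllm", 4, 8) == 4           # hypothetical backend name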
applications/ColossalChat/coati/distributed/producer.py

Lines changed: 5 additions & 4 deletions
@@ -284,7 +284,6 @@ def sync_data(self, data: Dict[str, torch.Tensor]) -> None:
         ray_broadcast_tensor_dict(data, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}")

     def loop(self) -> None:
-        # breakpoint()
         self.sync_model(0, 0)
         num_update_per_episode = len(self.train_dataloader) // self.num_microbatches
         num_valid_microbatches = num_update_per_episode * self.num_microbatches
@@ -620,10 +619,10 @@ async def generate(self, input_ids, attention_mask, **kwargs):
         rollouts = await asyncio.gather(*tasks)
         rollouts = {
             k: (
-                torch.cat([r[k] for r in rollouts], dim=0)
+                torch.cat([r[k] for r in rollouts], dim=0).cpu()
                 if k not in ["gt_answer", "test_cases"]
                 else [r[k] for r in rollouts]
-            ).cpu()  # CUDA tensor is not serializable by ray
+            )  # CUDA tensor is not serializable by ray
             for k in rollouts[0].keys()
         }
         rollouts["consumer_global_step"] = self.consumer_global_step
@@ -758,8 +757,8 @@ async def loop(self) -> None:
                     self.eval_mode = False
                     self.latest_eval_step = self.consumer_global_step
             self.profiler.enter("rollout")
-            # breakpoint()
             outputs = await self.rollout(**batch)
+            outputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}
             self.profiler.exit("rollout")
             outputs["temperature"] = torch.tensor(
                 [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
@@ -803,6 +802,8 @@ async def loop(self) -> None:
                 outputs.pop("gt_answer")
             if "test_cases" in outputs:
                 outputs.pop("test_cases")
+            if "consumer_global_step" in outputs:
+                outputs.pop("consumer_global_step")
             self.profiler.exit("calculate_reward")

             print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")

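The producer changes amount to a serialization round trip: rollout tensors are moved to CPU before they cross a Ray boundary (CUDA tensors are not serializable by Ray), list-valued entries such as gt_answer and test_cases pass through untouched, and the loop moves tensors back onto the producer's device right after the rollout so the reward calculation still runs locally. A minimal sketch of that round trip, assuming the dict layout used in the diff; the helper names are illustrative:

import torch

def to_cpu_for_ray(rollouts: dict) -> dict:
    """Move tensors to CPU so the dict can be broadcast by Ray; non-tensors pass through."""
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in rollouts.items()}

def back_to_device(outputs: dict, device: torch.device) -> dict:
    """Restore tensors to the local device before reward calculation."""
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in outputs.items()}

# Round trip: CPU for transport, back to the local device for compute.
batch = {"input_ids": torch.ones(2, 4, dtype=torch.long), "gt_answer": ["42", "7"]}
wire = to_cpu_for_ray(batch)
local = back_to_device(wire, torch.device("cuda" if torch.cuda.is_available() else "cpu"))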