Commit f10f707

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8ca76fe commit f10f707

4 files changed, +12 -8 lines changed

applications/ColossalChat/coati/distributed/agent/agentic_producer.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 import copy
-import random
 import re
 from typing import Any, Dict
 from uuid import uuid4

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 3 additions & 1 deletion
@@ -87,7 +87,9 @@ def launch_distributed(
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
-    num_recv_per_update = inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1
+    num_recv_per_update = (
+        inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1
+    )
 
     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
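
The reformatted ternary changes layout only, not behavior. A minimal sketch of the arithmetic it performs, using made-up sizes and placeholder backend names (only the variable names and the expression itself come from the diff):

# Illustration only: the sizes and backend names below are assumptions.
inference_batch_size = 32
inference_microbatch_size = 8

for inference_backend in ("vllm", "async-agentic-vllm"):  # placeholder backend names
    num_recv_per_update = (
        inference_batch_size // inference_microbatch_size if "async-agentic" not in inference_backend else 1
    )
    print(inference_backend, "->", num_recv_per_update)

# vllm               -> 4  (one receive per micro-batch)
# async-agentic-vllm -> 1  (a single receive per update)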

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 0 deletions
@@ -847,6 +847,7 @@ class AsyncSimpleProducer(BaseAsyncProducer):
     Asyncronous version of the producer that uses vLLM for generation.
     This class is designed to handle multiple producer actors and distribute tasks among them.
     """
+
     @torch.no_grad()
     async def rollout(self, input_ids, attention_mask, **kwargs):
         # naive rollout strategy without load balancing
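
The only change in this file is a cosmetic blank line after the docstring. Since the snippet mentions a "naive rollout strategy without load balancing", here is a hypothetical sketch of what such a strategy can look like, i.e. plain round-robin dispatch over producer engines with no load tracking; this is an illustration under assumed names, not the actual AsyncSimpleProducer code:

import asyncio
import itertools


# Hypothetical illustration; NaiveRolloutDispatcher and fake_engine are made-up names.
class NaiveRolloutDispatcher:
    def __init__(self, engines):
        self.engines = engines
        self._next = itertools.cycle(range(len(engines)))  # plain round-robin, no load awareness

    async def rollout(self, prompt):
        engine = self.engines[next(self._next)]  # picked regardless of how busy it is
        return await engine(prompt)


async def fake_engine(prompt):
    await asyncio.sleep(0)  # stands in for an asynchronous generation call
    return f"completion of {prompt!r}"


async def main():
    dispatcher = NaiveRolloutDispatcher([fake_engine, fake_engine])
    print(await dispatcher.rollout("hello"))


asyncio.run(main())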

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 8 additions & 6 deletions
@@ -1,12 +1,14 @@
 import json
 import os
+import random
 from typing import Any, Dict, List
-import asyncio
+
+import ray
 import torch
 from filelock import FileLock
-import random
+
 from colossalai.shardformer.layer.loss import dist_log_prob
-import ray
+
 
 def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
     batches = []
@@ -166,6 +168,7 @@ def safe_append_to_jsonl_file(file_path, data):
             json_line = json.dumps(entry, ensure_ascii=False)
             f.write(json_line + "\n")
 
+
 @ray.remote
 class LoadBalancer:
     def __init__(self, worker_counts):
@@ -180,10 +183,9 @@ def get_next_worker(self, worker_type, amount=1):
         chosen = random.choice(candidates)
         self.load[worker_type][chosen] += amount
         return chosen, self.load[worker_type]
-
+
     def increase_load(self, worker_type, worker_id, amount=1):
         self.load[worker_type][worker_id] += amount
-
+
     def decrease_load(self, worker_type, worker_id, amount=1):
         self.load[worker_type][worker_id] -= amount
-
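
The LoadBalancer touched above is a Ray actor, so callers interact with it through remote calls. A minimal usage sketch, assuming the module path mirrors the file path and that worker_counts maps a worker type to its number of workers (the "producer" key and the count are made up; the method names and signatures are the ones visible in the diff):

import ray

from coati.distributed.utils import LoadBalancer  # assumed import path, mirroring the file path above

ray.init(ignore_reinit_error=True)

# Assumption: worker_counts maps worker type -> number of workers of that type.
balancer = LoadBalancer.remote({"producer": 4})

# Ask for the next producer to use; per the diff this picks randomly among
# candidate workers and bumps the chosen worker's load by `amount`.
chosen, load = ray.get(balancer.get_next_worker.remote("producer", amount=1))
print("dispatching to producer", chosen, "load table:", load)

# Release the load once the dispatched task has finished.
ray.get(balancer.decrease_load.remote("producer", chosen, amount=1))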
