Skip to content

Commit 0f566cc

Browse files
author
Tong Li
committed
add algo selection
1 parent 812f4b7 commit 0f566cc

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22

33
import ray
44

5+
from .consumer import SimpleConsumer
56
from .grpo_consumer import GRPOConsumer
67
from .producer import SimpleProducer
78

9+
ALGO_MAP = {
10+
"Simple": SimpleConsumer,
11+
"GRPO": GRPOConsumer,
12+
}
13+
814

915
def get_jsonl_size_fast(path: str) -> int:
1016
with open(path) as f:
@@ -40,7 +46,14 @@ def launch_distributed(
4046
inference_backend: str = "transformers",
4147
master_addr: str = "localhost",
4248
master_port: int = 29500,
49+
core_algo: str = "GRPO",
4350
):
51+
52+
if core_algo not in ALGO_MAP:
53+
raise NotImplementedError(f"{core_algo} is not supported yet.")
54+
else:
55+
core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
56+
4457
train_dp_size = get_dp_size_fast(num_producers, plugin_config)
4558
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
4659

@@ -68,7 +81,7 @@ def launch_distributed(
6881
)
6982
procs.append(producer)
7083
for i in range(num_consumer_procs):
71-
consumer = GRPOConsumer.options(num_gpus=1).remote(
84+
consumer = core_consumer.options(num_gpus=1).remote(
7285
num_producers=num_producers,
7386
num_episodes=num_episodes,
7487
rank=i,

applications/ColossalChat/rl_example.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
1616
parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
1717
parser.add_argument("-b", "--backend", type=str, default="transformers")
18+
parser.add_argument("-a", "--algo", type=str, default="GPRO", choices=["Simple, GPRO"])
1819
args = parser.parse_args()
1920

2021
ray.init(address="local", namespace="ray-example")
@@ -95,4 +96,5 @@
9596
inference_backend=args.backend,
9697
master_addr="localhost",
9798
master_port=29504,
99+
core_algo=args.algo
98100
)

0 commit comments

Comments
 (0)