 import copy
+import os
 import uuid
 from typing import Any, Dict, Optional
 
@@ -64,24 +65,119 @@ def launch_distributed(
     core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
 
     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
+    print(f"inference_batch_size {inference_batch_size} num_producers {num_producers} train_batch_size {train_batch_size} train_dp_size {train_dp_size}")
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
-
+
     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
     rollout_log_file = os.path.join(
         rollout_save_dir,
         f"{project_name.replace(' ', '_')}_run_{wandb_group_name}.jsonl",
     )
 
-    procs = []
+
+    # ###########################################
+    # # Old version, may cause colossalai init to get stuck on multi-node runs
+    # ############################################
+    # procs = []
+    # for i in range(num_producers):
+    #     # producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+    #     producer = SimpleProducer.options(num_cpus=1, resources={"NPU": num_proc_per_producer}).remote(
+    #         producer_idx=i,
+    #         num_producers=num_producers,
+    #         num_consumer_procs=num_consumer_procs,
+    #         num_episodes=num_episodes,
+    #         batch_size=inference_batch_size,
+    #         dataset_config=dataset_config,
+    #         dataloaders_config=dataloaders_config,
+    #         model_config=inference_model_config,
+    #         generate_config=generate_config,
+    #         tokenizer_config=tokenizer_config,
+    #         microbatch_size=inference_microbatch_size,
+    #         backend=inference_backend,
+    #         num_generations=num_generations,
+    #         consumer_plugin_config=plugin_config,
+    #     )
+    #     procs.append(producer)
+    # generate_config_consumer = copy.deepcopy(generate_config)
+    # generate_config_consumer.update(
+    #     dict(
+    #         backend=inference_backend,
+    #     )
+    # )
+    # for i in range(num_consumer_procs):
+    #     # consumer = core_consumer.options(num_gpus=1).remote(
+    #     consumer = core_consumer.options(num_cpus=1, resources={"NPU": 1}).remote(
+    #         num_producers=num_producers,
+    #         num_episodes=num_episodes,
+    #         rank=i,
+    #         world_size=num_consumer_procs,
+    #         master_addr=master_addr,
+    #         master_port=master_port,
+    #         num_update_per_episode=num_update_per_episode,
+    #         num_recv_per_update=num_recv_per_update,
+    #         batch_size=train_batch_size,
+    #         model_config=train_model_config,
+    #         plugin_config=plugin_config,
+    #         minibatch_size=train_minibatch_size,
+    #         generate_config=generate_config_consumer,
+    #         grpo_config=grpo_config,
+    #         num_generations=num_generations,
+    #         project_name=project_name,
+    #         save_interval=save_interval,
+    #         save_dir=save_dir,
+    #     )
+    #     procs.append(consumer)
+    # ray.get([p.setup.remote() for p in procs])
+    # ray.get([p.loop.remote() for p in procs])
+
+    ###########################################
+    # New version: assign a master IP for colossalai and vllm respectively
+    ###########################################
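+    # Discover every NPU slot in the Ray cluster (node id + node IP), then pin each
+    # producer/consumer actor to a concrete node via NodeAffinitySchedulingStrategy.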
+    nodes = ray.nodes()
+    node_info = {
+        node["NodeID"]: {
+            # "num_gpus": node["Resources"].get("GPU", 0),
+            "num_gpus": node["Resources"].get("NPU", 0),  # default to 0 if no NPUs are available
+            "address": node["NodeManagerAddress"],
+        }
+        for node in nodes
+    }
+    print(f"node_info {node_info}")
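+    # Flatten into per-NPU-slot lists: slot i lives on gpu_to_node_id[i] / gpu_to_ip_address[i].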
+    gpu_to_node_id = []
+    gpu_to_ip_address = []
+    for node_id in node_info:
+        for idx in range(int(node_info[node_id]["num_gpus"])):  # use num_gpus instead of num_npus
+            gpu_to_node_id.append(node_id)
+            gpu_to_ip_address.append(node_info[node_id]["address"])
+    print(f"node_info {node_info}\n gpu_to_node_id {gpu_to_node_id}\n gpu_to_ip_address {gpu_to_ip_address}\n")
+
+    producer_procs = []
+
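+    # Each producer claims num_proc_per_producer NPU slots from the head of the list
+    # and is pinned (soft=False) to the node of the first claimed slot.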
     for i in range(num_producers):
-        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+        node_id = gpu_to_node_id[0]
+        producer_ip_address = gpu_to_ip_address[0]
+        for _ in range(num_proc_per_producer):
+            gpu_to_node_id.pop(0)
+            gpu_to_ip_address.pop(0)
+        print(f"Schedule Producer P[{i}] which requires {num_proc_per_producer} NPUs on node {producer_ip_address}")
+
+        producer = SimpleProducer.options(
+            # num_cpus=1,
+            # num_cpus=num_proc_per_producer,
+            num_gpus=0,
+            resources={"NPU": num_proc_per_producer},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,
@@ -107,20 +203,36 @@ def launch_distributed(
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
         )
-        procs.append(producer)
+        producer_procs.append(producer)
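+    # Finish producer setup before any consumer actor is created.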
+    ray.get([p.setup.remote() for p in producer_procs])
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
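+    # The next unclaimed NPU slot will host consumer rank 0; its node IP becomes the torch DDP master address.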
+    consumer_master_ip_address = gpu_to_ip_address[0]
+    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
+    consumer_procs = []
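+    # Each consumer claims exactly one NPU slot and is pinned (soft=False) to that slot's node.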
     for i in range(num_consumer_procs):
-        consumer = core_consumer.options(num_gpus=1).remote(
+        node_id = gpu_to_node_id[0]
+        consumer_ip_address = gpu_to_ip_address[0]
+        gpu_to_node_id.pop(0)
+        gpu_to_ip_address.pop(0)
+        print(f"Schedule Consumer T[{i}] which requires 1 NPU on node {consumer_ip_address}")
+        consumer = core_consumer.options(
+            resources={"NPU": 1},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
             world_size=num_consumer_procs,
-            master_addr=master_addr,
+            # master_addr=master_addr,
+            master_addr=consumer_master_ip_address,
             master_port=master_port,
             num_update_per_episode=num_update_per_episode,
             num_recv_per_update=num_recv_per_update,
@@ -137,6 +249,6 @@ def launch_distributed(
             run_name=run_name,
             wandb_group_name=wandb_group_name,
         )
-        procs.append(consumer)
-    ray.get([p.setup.remote() for p in procs])
-    ray.get([p.loop.remote() for p in procs])
+        consumer_procs.append(consumer)
+    ray.get([p.setup.remote() for p in consumer_procs])
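+    # Run producer and consumer main loops concurrently.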
+    ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])