Commit d843f02

fix: Fix policy worker placement when using unified placement group (#1341)
Signed-off-by: Guyue Huang <[email protected]>
Signed-off-by: Guyue Huang <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 3a69c21 commit d843f02

File tree

10 files changed (+176, -8 lines)


nemo_rl/algorithms/dpo.py

Lines changed: 3 additions & 0 deletions
@@ -248,6 +248,9 @@ def setup(
         init_optimizer=True,
         init_reference_model=True,
     )
+    # print the node IP and GPU ID of the policy workers for debugging
+    policy.print_node_ip_and_gpu_id()
+
     loss_fn = DPOLossFn(master_config["dpo"])
     print(" ✓ Model initialized")

nemo_rl/algorithms/grpo.py

Lines changed: 2 additions & 0 deletions
@@ -482,6 +482,8 @@ def setup(
         optimizer_path=optimizer_path,
         init_optimizer=True,
     )
+    # print the node IP and GPU ID of the policy workers for debugging
+    policy.print_node_ip_and_gpu_id()

     # if it is not colocated inference, initialize collective communication for update weights
     if not colocated_inference:

nemo_rl/algorithms/rm.py

Lines changed: 3 additions & 0 deletions
@@ -223,6 +223,9 @@ def setup(
         init_optimizer=True,
         init_reference_model=False,
     )
+    # print the node IP and GPU ID of the policy workers for debugging
+    policy.print_node_ip_and_gpu_id()
+
     loss_fn = PreferenceLoss()
     print(" ✓ Model initialized")

nemo_rl/algorithms/sft.py

Lines changed: 3 additions & 0 deletions
@@ -202,6 +202,9 @@ def setup(
         init_optimizer=True,
         init_reference_model=False,
     )
+    # print the node IP and GPU ID of the policy workers for debugging
+    policy.print_node_ip_and_gpu_id()
+
     loss_fn = NLLLoss()
     print(" ✓ Model initialized")

nemo_rl/distributed/virtual_cluster.py

Lines changed: 70 additions & 0 deletions
@@ -21,6 +21,7 @@
 from ray.util.placement_group import (
     PlacementGroup,
     placement_group,
+    placement_group_table,
     remove_placement_group,
 )
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -170,6 +171,14 @@ def init_ray(log_dir: Optional[str] = None) -> None:
     )


+@ray.remote(num_gpus=1)
+class GetGPUIDActor:  # pragma: no cover
+    """Util actor class to return GPU id of the current worker."""
+
+    def get_gpu_id(self):
+        return ray.get_gpu_ids()[0]
+
+
 class ResourceInsufficientError(Exception):
     """Exception raised when the cluster does not have enough resources to satisfy the requested configuration."""
@@ -210,6 +219,7 @@ def __init__(
         self._bundle_ct_per_node_list = bundle_ct_per_node_list
         self._world_size = sum(self._bundle_ct_per_node_list)
         self._node_placement_groups: Optional[list[PlacementGroup]] = None
+        self._sorted_bundle_indices: Optional[list[int]] = None

         self.num_gpus_per_node = num_gpus_per_node
         self.use_gpus = use_gpus
@@ -251,6 +261,8 @@ def _init_placement_groups(
             self._node_placement_groups = self._create_placement_groups_internal(
                 strategy, use_unified_pg
             )
+            if use_unified_pg and self.use_gpus:
+                self._sorted_bundle_indices = self._get_sorted_bundle_indices()
             return self._node_placement_groups
         except ResourceInsufficientError as e:
             print(e)
@@ -402,8 +414,66 @@ def get_master_address_and_port(self) -> tuple[str, int]:
         Returns:
             Tuple of (address, port)
         """
+        # Get placement groups if not already created
+        if not self._node_placement_groups:
+            self.get_placement_groups()
+
+        # If sorted bundle indices are available, get the address and port for the first bundle index
+        if self._sorted_bundle_indices is not None:
+            return self.get_available_address_and_port(
+                pg_idx=0, bundle_idx=self._sorted_bundle_indices[0]
+            )
+
+        # Otherwise, get the address and port for bundle index 0
         return self.get_available_address_and_port(pg_idx=0, bundle_idx=0)

+    def _get_sorted_bundle_indices(self) -> Optional[list[int]]:
+        """Gets the sorted bundle indices for the placement groups."""
+        if self._node_placement_groups is None:
+            raise ValueError(
+                "Placement groups must be initialized before calling _get_sorted_bundle_indices"
+            )
+
+        if not self.use_gpus:
+            return None
+
+        if len(self._node_placement_groups) != 1:
+            return None
+
+        pg = self._node_placement_groups[0]
+        pg_data = placement_group_table(pg)
+        num_bundles = len(pg_data["bundles"])
+        bundle_to_node_ids = pg_data["bundles_to_node_id"]
+
+        # use info actor to get the GPU id
+        info_actors = []
+        for i in range(num_bundles):
+            info_actors.append(
+                GetGPUIDActor.options(
+                    num_cpus=0.01,  # set both num_cpus and num_gpus to be small values to enable assignment in colocated case
+                    num_gpus=0.01,
+                    resources=None,
+                    scheduling_strategy=PlacementGroupSchedulingStrategy(
+                        placement_group=pg,
+                        placement_group_bundle_index=i,
+                    ),
+                ).remote()
+            )
+
+        gpu_ids = ray.get([actor.get_gpu_id.remote() for actor in info_actors])
+        for actor in info_actors:
+            ray.kill(actor)
+
+        # original index, node_id, gpu_id
+        bundle_infos = [
+            (i, bundle_to_node_ids[i], gpu_ids[i]) for i in range(num_bundles)
+        ]
+        pg_reordered_bundle_indices = [
+            bundle_info[0]
+            for bundle_info in sorted(bundle_infos, key=lambda x: (x[1], x[2]))
+        ]  # sort by node_id, then gpu_id
+        return pg_reordered_bundle_indices
+
     def shutdown(self) -> bool:
         """Cleans up and releases all resources associated with this virtual cluster.
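
The heart of the fix is the (node_id, gpu_id) sort in _get_sorted_bundle_indices above: with a unified placement group, Ray's bundle order is not guaranteed to follow node or GPU order, so the cluster probes each bundle with a short-lived GetGPUIDActor and reorders the bundle indices. A minimal standalone sketch of that reordering, with invented bundle metadata standing in for placement_group_table() and the actor probes:

# Illustrative only: bundle_to_node_id and gpu_ids are invented here; in the
# commit they come from placement_group_table(pg) and the GetGPUIDActor probes.
bundle_to_node_id = {0: "node-b", 1: "node-a", 2: "node-a", 3: "node-b"}
gpu_ids = {0: 1, 1: 0, 2: 1, 3: 0}

bundle_infos = [(i, bundle_to_node_id[i], gpu_ids[i]) for i in range(4)]
sorted_indices = [
    info[0] for info in sorted(bundle_infos, key=lambda x: (x[1], x[2]))
]
# Sorted by node, then GPU: node-a/GPU0, node-a/GPU1, node-b/GPU0, node-b/GPU1
print(sorted_indices)  # [1, 2, 3, 0]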

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 6 additions & 0 deletions
@@ -1929,3 +1929,9 @@ def start_gpu_profiling(self) -> None:
     def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         torch.cuda.profiler.stop()
+
+    def report_node_ip_and_gpu_id(self) -> list[tuple[str, int]]:
+        """Report the node IP and GPU ID of the current worker."""
+        ip = ray._private.services.get_node_ip_address()
+        gpu_id = ray.get_gpu_ids()[0]
+        return (ip, gpu_id)
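
The report_node_ip_and_gpu_id method above (duplicated in the other policy workers below) pairs a worker's node IP with its assigned GPU. As a rough standalone sketch of the same idea outside the worker classes, a plain Ray actor could report the pair via the public ray.util.get_node_ip_address() instead of the ray._private call; the actor name here is made up:

import ray

@ray.remote(num_gpus=1)
class IpAndGpuReporter:  # hypothetical helper, for illustration only
    def report(self):
        # (node IP, GPU id) of the process this actor runs in
        return ray.util.get_node_ip_address(), ray.get_gpu_ids()[0]

# Usage sketch; requires a running Ray cluster with at least one free GPU:
# ray.init()
# print(ray.get(IpAndGpuReporter.remote().report.remote()))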

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 6 additions & 0 deletions
@@ -1913,3 +1913,9 @@ def start_gpu_profiling(self) -> None:
     def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         torch.cuda.profiler.stop()
+
+    def report_node_ip_and_gpu_id(self) -> list[tuple[str, int]]:
+        """Report the node IP and GPU ID of the current worker."""
+        ip = ray._private.services.get_node_ip_address()
+        gpu_id = ray.get_gpu_ids()[0]
+        return (ip, gpu_id)

nemo_rl/models/policy/lm_policy.py

Lines changed: 60 additions & 8 deletions
@@ -167,14 +167,33 @@ def __init__(
             pre_init_communication_queue=pre_init_queue,
         )

-        self.worker_group = RayWorkerGroup(
-            cluster,
-            worker_builder,
-            name_prefix=name_prefix,
-            workers_per_node=workers_per_node,
-            sharding_annotations=self.sharding_annotations,
-            env_vars=env_vars or {},
-        )
+        if cluster._sorted_bundle_indices is not None:
+            # The cluster has initialized a unified placement group across nodes
+            # In this case, we need to create workers based on sorted bundle indices
+            group_size = cluster.num_gpus_per_node
+            tied_groups = [
+                (i // group_size, [bundle_idx])
+                for i, bundle_idx in enumerate(cluster._sorted_bundle_indices)
+            ]
+
+            self.worker_group = RayWorkerGroup(
+                cluster,
+                worker_builder,
+                name_prefix=name_prefix,
+                bundle_indices_list=tied_groups,
+                sharding_annotations=self.sharding_annotations,
+                env_vars=env_vars or {},
+            )
+
+        else:
+            self.worker_group = RayWorkerGroup(
+                cluster,
+                worker_builder,
+                name_prefix=name_prefix,
+                workers_per_node=workers_per_node,
+                sharding_annotations=self.sharding_annotations,
+                env_vars=env_vars or {},
+            )

         if config["dynamic_batching"]["enabled"]:
             assert pp_size == 1, (
@@ -755,3 +774,36 @@ def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         futures = self.worker_group.run_all_workers_single_data("stop_gpu_profiling")
         ray.get(futures)
+
+    def print_node_ip_and_gpu_id(self) -> list[tuple[str, int]]:
+        """Print the node IP and GPU ID of the current worker."""
+        results = ray.get(
+            self.worker_group.run_all_workers_single_data(
+                "report_node_ip_and_gpu_id",
+            )
+        )
+        all_node_ips = sorted(set([result[0] for result in results]))
+        all_gpu_ids = sorted(set([result[1] for result in results]))
+
+        worker_id_list = [
+            [list() for _ in range(len(all_gpu_ids))] for _ in range(len(all_node_ips))
+        ]
+        for worker_id, (ip, gpu_id) in enumerate(results):
+            node_idx = all_node_ips.index(ip)
+            gpu_idx = all_gpu_ids.index(gpu_id)
+            worker_id_list[node_idx][gpu_idx].append("worker-" + str(worker_id))
+
+        from prettytable import PrettyTable
+
+        table = PrettyTable()
+        table.title = "Policy worker mapping to Nodes and GPUs"
+        table.field_names = ["Node_IP"] + [
+            "GPU_ID=" + str(gpu_id) for gpu_id in all_gpu_ids
+        ]
+        for i, node_idx in enumerate(all_node_ips):
+            row = [node_idx]
+            for j in range(len(all_gpu_ids)):
+                row.append(tuple(worker_id_list[i][j]))
+            table.add_row(row)
+
+        print(table)
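
To make the tied_groups construction above concrete, here is a small sketch with assumed inputs (a pre-sorted bundle index list and 2 GPUs per node); it only shows the grouping arithmetic, not the RayWorkerGroup call:

# Assumed inputs, for illustration only
sorted_bundle_indices = [1, 2, 3, 0]  # e.g. the output of the node/GPU sort
num_gpus_per_node = 2

tied_groups = [
    (i // num_gpus_per_node, [bundle_idx])
    for i, bundle_idx in enumerate(sorted_bundle_indices)
]
print(tied_groups)  # [(0, [1]), (0, [2]), (1, [3]), (1, [0])]

Each entry pairs a node-group id with a single-bundle list, so after the (node, GPU) sort the workers belonging to one node land in consecutive groups.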

nemo_rl/models/policy/megatron_policy_worker.py

Lines changed: 6 additions & 0 deletions
@@ -2228,3 +2228,9 @@ def start_gpu_profiling(self) -> None:
     def stop_gpu_profiling(self) -> None:
         """Stop GPU profiling."""
         torch.cuda.profiler.stop()
+
+    def report_node_ip_and_gpu_id(self) -> list[tuple[str, int]]:
+        """Report the node IP and GPU ID of the current worker."""
+        ip = ray._private.services.get_node_ip_address()
+        gpu_id = ray.get_gpu_ids()[0]
+        return (ip, gpu_id)

tests/unit/distributed/test_virtual_cluster.py

Lines changed: 17 additions & 0 deletions
@@ -229,3 +229,20 @@ def test_mcore_py_executable():
     assert "megatron-bridge is imported" in result.stdout
     assert "megatron-core is imported" in result.stdout
     assert "megatron-training is imported" in result.stdout
+
+
+def test_create_sorted_bundle_indices_for_unified_pg():
+    """Test that sorted bundle indices are created for a unified placement group."""
+    cluster = RayVirtualCluster(bundle_ct_per_node_list=[2], use_gpus=True)
+    cluster._init_placement_groups(strategy=None, use_unified_pg=True)
+    assert cluster._sorted_bundle_indices is not None
+    assert len(cluster._sorted_bundle_indices) == 2
+    assert 0 in cluster._sorted_bundle_indices
+    assert 1 in cluster._sorted_bundle_indices
+
+
+def test_not_create_sorted_bundle_indices_for_per_node_pg():
+    """Test that sorted bundle indices are not created for a per-node placement group."""
+    cluster = RayVirtualCluster(bundle_ct_per_node_list=[2], use_gpus=True)
+    cluster._init_placement_groups(strategy=None, use_unified_pg=False)
+    assert cluster._sorted_bundle_indices is None
