|
1 | | -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 1 | +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
@@ -57,14 +57,12 @@ def broadcast_obj_from_pp_rank(obj: Any) -> Any: |
57 | 57 | # ------------------------------------------------------------------ |
58 | 58 | # 2. Identify the owning rank (the only rank with True flag) |
59 | 59 | # ------------------------------------------------------------------ |
60 | | - src_rank = None # Rank *inside* the PP group |
61 | | - for rank, flag in enumerate(obj_flags): |
62 | | - if flag: |
63 | | - src_rank = rank |
64 | | - break |
65 | | - |
66 | | - if src_rank is None: |
| 60 | + true_ranks = [rank for rank, flag in enumerate(obj_flags) if flag] |
| 61 | + if not true_ranks: |
67 | 62 | raise ValueError("Object must exist on at least one PP rank") |
| 63 | + if len(true_ranks) > 1: |
| 64 | + raise ValueError(f"Object present on multiple PP ranks: {true_ranks}") |
| 65 | + src_rank = true_ranks[0] |
68 | 66 |
|
69 | 67 | # ------------------------------------------------------------------ |
70 | 68 | # 3. Broadcast the object from the source rank to all ranks |
@@ -135,12 +133,11 @@ def broadcast_tensors_from_last_stage( |
135 | 133 | if is_pipeline_last_stage(ignore_virtual=True): |
136 | 134 | # Broadcast tensors from last stage |
137 | 135 | for name, tensor in tensors.items(): |
138 | | - if tensor is not None: |
139 | | - broadcasted_tensors[name] = broadcast_tensor( |
140 | | - tensor, current_rank, pp_group |
| 136 | + if tensor is None: |
| 137 | + raise ValueError( |
| 138 | + f"Last PP stage must provide tensor '{name}' for broadcast." |
141 | 139 | ) |
142 | | - else: |
143 | | - broadcasted_tensors[name] = None |
| 140 | + broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group) |
144 | 141 | else: |
145 | 142 | # Receive tensors on other stages |
146 | 143 | for name in tensors.keys(): |
|
0 commit comments