
Commit 8614cd3

[None][fix] fix: resolve GPU memory imbalance in concurrent weight loading (#6472)
Signed-off-by: Necofish <[email protected]>
Signed-off-by: Nekofish-L <[email protected]>
Signed-off-by: Jie Li <[email protected]>
Co-authored-by: Jie Li <[email protected]>
1 parent e2891a6 commit 8614cd3

File tree

3 files changed: +22 -3 lines changed


examples/llm-api/quickstart_advanced.py

Lines changed: 6 additions & 0 deletions
@@ -77,6 +77,11 @@ def add_llm_args(parser):
                         choices=["auto", "TorchSampler", "TRTLLMSampler"])
     parser.add_argument('--tp_size', type=int, default=1)
     parser.add_argument('--pp_size', type=int, default=1)
+    parser.add_argument('--orchestrator_type',
+                        type=str,
+                        default=None,
+                        choices=[None, 'rpc', 'ray'],
+                        help='Orchestrator type for multi-GPU execution')
     parser.add_argument('--moe_ep_size', type=int, default=-1)
     parser.add_argument('--moe_tp_size', type=int, default=-1)
     parser.add_argument('--moe_cluster_size', type=int, default=-1)

@@ -288,6 +293,7 @@ def setup_llm(args, **kwargs):
         trust_remote_code=args.trust_remote_code,
         gather_generation_logits=args.return_generation_logits,
         max_beam_width=args.max_beam_width,
+        orchestrator_type=args.orchestrator_type,
         **kwargs)

     use_beam_search = args.max_beam_width > 1
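The quickstart change adds an `--orchestrator_type` flag and forwards it to the `LLM` constructor in `setup_llm`. A minimal sketch of how the new flag parses, using a standalone `argparse` parser rather than the full `add_llm_args` setup:

```python
import argparse

# Standalone sketch of the new flag; quickstart_advanced.py registers it
# inside add_llm_args alongside --tp_size and --pp_size.
parser = argparse.ArgumentParser()
parser.add_argument('--orchestrator_type',
                    type=str,
                    default=None,
                    choices=[None, 'rpc', 'ray'],
                    help='Orchestrator type for multi-GPU execution')

args = parser.parse_args(['--orchestrator_type', 'ray'])
print(args.orchestrator_type)  # 'ray'; stays None when the flag is omitted
```

When the flag is left at its default, `orchestrator_type=None` is passed through to `setup_llm`, so existing invocations keep their previous behavior.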

tensorrt_llm/_torch/models/modeling_utils.py

Lines changed: 7 additions & 2 deletions
@@ -12,6 +12,7 @@
 from torch.utils._pytree import tree_any_only
 from tqdm import tqdm

+from tensorrt_llm._utils import local_mpi_rank
 from tensorrt_llm.lora_manager import HfLoraLoader
 from tensorrt_llm.models.convert_utils import split_matrix_tp

@@ -852,8 +853,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
         'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
         'gate_up_proj': ['gate_proj', 'up_proj']
     }
+    device_id = local_mpi_rank()

     def load_single_module(name, module):
+        torch.cuda.set_device(device_id)
         if len(module._parameters) > 0:
             # skip load weights if module is in skip_modules
             if any(skip_module in name for skip_module in skip_modules):
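The `torch.cuda.set_device(device_id)` call at the top of `load_single_module` matters because the current CUDA device is thread-local: worker threads spawned for concurrent weight loading start on `cuda:0` regardless of which GPU the rank's main thread is pinned to, which is what skewed memory toward GPU 0. A minimal sketch of the pattern (not the TRT-LLM code path), with a hypothetical `load_one` helper:

```python
from concurrent.futures import ThreadPoolExecutor

import torch


def load_one(device_id: int):
    # Without this call every pool thread allocates on cuda:0, because new
    # threads do not inherit the main thread's current CUDA device.
    torch.cuda.set_device(device_id)
    return torch.empty(1024, 1024, device="cuda")  # lands on this thread's current device


if torch.cuda.is_available():
    device_id = 0  # in the patch this comes from local_mpi_rank()
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(load_one, device_id) for _ in range(8)]
        tensors = [f.result() for f in futures]
    print(tensors[0].device)  # cuda:<device_id>
```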
@@ -931,7 +934,7 @@ def load_single_module(name, module):
                     p.data.copy_(module_weights[n][:])

     if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
-                      "True") in ["True", "true", "1", "yes", "y"]:
+                      "False") in ["True", "true", "1", "yes", "y"]:
         for name, module in tqdm(list(
                 model.named_modules(remove_duplicate=False)),
                              desc="Loading weights"):

@@ -977,8 +980,10 @@ def _load_weights_impl_v2(model: Union[nn.Module, DecoderModelForCausalLM],
     if params_map is not None:
         weights = weight_mapper.rename_by_params_map(params_map, weights)
         logger.info(f"Renamed weights with params_map: {params_map}")
+    device_id = local_mpi_rank()

     def load_single_module(name, module):
+        torch.cuda.set_device(device_id)
         if len(module._parameters) > 0:
             if weight_mapper.should_skip_module(name):
                 return

@@ -1034,7 +1039,7 @@ def load_single_module(name, module):
                              allow_partial_loading=allow_partial_loading)

     if os.environ.get("TRT_LLM_DISABLE_LOAD_WEIGHTS_IN_PARALLEL",
-                      "True") in ["True", "true", "1", "yes", "y"]:
+                      "False") in ["True", "true", "1", "yes", "y"]:
         for name, module in tqdm(list(
                 model.named_modules(remove_duplicate=False)),
                              desc="Loading weights"):

tensorrt_llm/_utils.py

Lines changed: 9 additions & 1 deletion
@@ -561,7 +561,15 @@ def mpi_world_size():


 def local_mpi_rank():
-    return local_comm.Get_rank() if ENABLE_MULTI_DEVICE else 0
+    if mpi_disabled():
+        # For Ray/non-MPI: the device was already set during worker init
+        # torch.cuda.current_device() returns the correct local device ID
+        try:
+            return torch.cuda.current_device()
+        except ValueError:
+            return 0
+    return mpi_comm().Get_rank() % torch.cuda.device_count(
+    ) if ENABLE_MULTI_DEVICE else 0


 def local_mpi_size():
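With the `_utils.py` change, the MPI path derives the local device from the global rank modulo the number of visible GPUs instead of a node-local communicator rank, while the Ray/non-MPI path trusts the device already selected during worker init. A hedged illustration of the modulo mapping, using a hypothetical `local_device_from_global_rank` helper and assuming ranks are laid out contiguously per node:

```python
def local_device_from_global_rank(global_rank: int, gpus_per_node: int) -> int:
    # Same arithmetic as mpi_comm().Get_rank() % torch.cuda.device_count()
    return global_rank % gpus_per_node


# Two nodes with 4 GPUs each: ranks 0-3 -> GPUs 0-3 on node 0,
# ranks 4-7 -> GPUs 0-3 on node 1.
for rank in range(8):
    print(rank, "->", local_device_from_global_rank(rank, 4))
```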
