Commit b2095aa

[#4674][bugfix] AutoDeploy Fix memory leak in fuse_moe (#7844)

Delete the unstacked weights immediately to save GPU memory. Cleanup occurs automatically after the transformation, but for large models we run out of memory during the transformation itself.

Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 20e6cd3 commit b2095aa
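
For a rough sense of scale (illustrative numbers only, not taken from the commit): if one MoE layer has 256 experts with 16 MB of bf16 weights each, its unstacked expert weights occupy about 4 GB; stacking them for the fused op allocates roughly another 4 GB, so peak usage for that layer is around 8 GB until the unstacked copies are released. A cleanup that only runs after the whole transformation lets this transient overhead accumulate across every rewritten layer, which is what caused the OOMs tracked in #4674.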

File tree

3 files changed: +55, -1 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 1 deletion

@@ -107,7 +107,7 @@ transforms:
     backend: trtllm
   fuse_moe:
     stage: post_load_fusion
-    enabled: false # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
+    enabled: true
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
   fuse_collectives:

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 6 additions & 0 deletions

@@ -57,6 +57,12 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
         node.replace_all_uses_with(new_node)
         graph.erase_node(node)

+    # Delete the unstacked weights immediately to save GPU memory
+    # This will happen automatically after the graph is canonicalized, but for large models we'll run out of memory
+    # during the transformation itself.
+    gm.graph.eliminate_dead_code()
+    gm.delete_all_unused_submodules()
+
     return fused_key_counter
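
As context for why the two new calls are needed, here is a minimal, self-contained sketch of the torch.fx behavior involved. It is not TensorRT-LLM code: ToyExperts and the einsum rewrite are made up for illustration (and kept on CPU), but eliminate_dead_code() and delete_all_unused_submodules() are the same GraphModule APIs used in the commit. Until they run, the dead per-expert nodes keep the original submodules, and therefore the unstacked weights, registered on the GraphModule.

# Minimal sketch; hypothetical ToyExperts module, CPU tensors for portability.
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class ToyExperts(nn.Module):
    def __init__(self, num_experts: int = 4, dim: int = 64):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(dim, dim, bias=False) for _ in range(num_experts))

    def forward(self, x):
        # Stand-in for real MoE routing: just sum all expert outputs.
        return sum(e(x) for e in self.experts)


gm = symbolic_trace(ToyExperts())

# "Fuse": stack the per-expert weights into one parameter and rewire the graph
# to a single op, loosely mimicking what the fuse_moe transform does.
gm.register_parameter(
    "fused_weight", nn.Parameter(torch.stack([e.weight.detach() for e in gm.experts]))
)
graph = gm.graph
x_node = next(n for n in graph.nodes if n.op == "placeholder")
out_node = next(n for n in graph.nodes if n.op == "output")
with graph.inserting_before(out_node):
    w_node = graph.get_attr("fused_weight")
    fused = graph.call_function(torch.einsum, args=("bi,eoi->bo", x_node, w_node))
out_node.args = (fused,)

# The old expert nodes are now dead, but their Linear submodules (and weights,
# which would be CUDA tensors in a real model) are still attached to gm.
print(len(dict(gm.named_parameters())))  # 5: fused_weight + 4 stale expert weights

# The two cleanup calls added by this commit: drop the dead nodes, then detach
# the now-unused submodules so the stale weights can actually be freed.
gm.graph.eliminate_dead_code()
gm.delete_all_unused_submodules()
gm.recompile()

print(len(dict(gm.named_parameters())))  # 1: only fused_weight remains
_ = gm(torch.randn(2, 64))  # the rewritten graph still runs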

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py

Lines changed: 48 additions & 0 deletions

@@ -368,3 +368,51 @@ def test_moe_fusion():
         num_param_nodes_fused < num_param_nodes
     ), f"""number of parameter nodes after fusion {num_param_nodes_fused} <
     number of parameter nodes before fusion {num_param_nodes}"""
+
+
+def test_fuse_moe_cleanup():
+    # Ensure deterministic allocations and a clean slate
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed(1234)
+    torch.cuda.empty_cache()
+
+    device = "cuda"
+    dtype = torch.bfloat16
+
+    # Build model and export to GraphModule (pre-fusion)
+    model = MoEOpModel().to(device=device, dtype=dtype)
+    x = model.get_input(device=device, dtype=dtype)
+    gm = torch_export_to_gm(model, args=(x,), clone=True)
+
+    # Count parameters and measure memory before fusion
+    num_param_nodes_before = len(list(gm.named_parameters()))
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+    mem_before = torch.cuda.memory_allocated()
+
+    # Apply MoE fusion which should stack weights and clean up unstacked params
+    # We need to ensure the cleanup is done as part of the transformation to avoid OOM during the transformation itself.
+    gm_transformed = InferenceOptimizer(
+        None,
+        {
+            "fuse_moe": {
+                "stage": "post_load_fusion",
+                "run_graph_cleanup": False,  # verify cleanup is done as part of the transformation
+                "run_shape_prop": False,  # shape_prop can also trigger cleanup
+            },
+        },
+    )(None, gm)
+
+    # Ensure that parameter count decreased after fusion (unstacked params cleaned)
+    num_param_nodes_after = len(list(gm_transformed.named_parameters()))
+    assert num_param_nodes_after < num_param_nodes_before, (
+        f"Expected fewer parameters after fusion: before={num_param_nodes_before}, after={num_param_nodes_after}"
+    )
+
+    # Memory should not increase after fusion/cleanup
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
+    mem_after = torch.cuda.memory_allocated()
+    assert mem_after <= mem_before, (
+        f"CUDA memory increased after fusion: before={mem_before} after={mem_after}"
+    )
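
To run just this new regression test locally (assuming the repository's usual pytest-based unit-test workflow and a CUDA device, since the test allocates bf16 weights on the GPU):

pytest tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py::test_fuse_moe_cleanup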
