
Commit 6bc3d6f (parent ba86395)
[Graph Partition] fix partition x memory plan issue (pytorch#166984)
[Graph Partition] fix partition x memory plan issue (pytorch#165514)

For `test_graph_partition_with_memory_plan_reuse`, before this PR, using graph partition would error ([P1992728479](https://www.internalfb.com/phabricator/paste/view/P1992728479)):

```
def partition_0(args):
    ...
    del buf0
    return (buf3, buf4, buf5, buf2, primals_4, )
...
  File "/tmp/torchinductor_boyuan/ww/cwwc7ukfqscg2vy6ankby2fizdb377tvgyx3fwdgddrxe3g47jg6.py", line 132, in partition_0
    return (buf3, buf4, buf5, buf2, primals_4, )
                              ^^^^
NameError: name 'buf2' is not defined. Did you mean: 'buf0'?
```

Without graph partition, the test works and generates the following code ([P1992997521](https://www.internalfb.com/phabricator/paste/view/P1992997521)):

```
def call(self, args):
    ...
    buf2 = buf0; del buf0  # reuse
    ...
```

The issue is that buf0 is not reused for buf2 when graph partition is enabled. Why? Because codegen runs `run_wrapper_ir_passes` and `memory_plan_reuse`, which pops trailing `MemoryPlanningLine`s unless they correspond to graph outputs, as determined by `V.graph.get_output_names()`. For graph partition, however, we should check the outputs of the current partition rather than those of the graph before partitioning.

Pull Request resolved: pytorch#165514
Approved by: https://github.com/ProExpertProg, https://github.com/eellison

(cherry picked from commit f071f17)
Co-authored-by: Boyuan Feng <[email protected]>
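The message above describes `memory_plan_reuse` popping trailing `MemoryPlanningLine`s that are not among the output names. Below is a minimal, self-contained sketch of that pruning idea, not Inductor's actual implementation: `MemoryPlanningLine`, `buffer_name`, and the output-name sets are simplified stand-ins. It only illustrates why checking the whole graph's outputs, instead of the current partition's outputs, can drop the reuse line for a buffer like buf2 that the partition still returns.

```python
# Minimal, self-contained sketch of the pruning idea described above -- NOT
# Inductor's real implementation. MemoryPlanningLine, buffer_name, and the
# output-name sets below are simplified stand-ins for illustration only.
from dataclasses import dataclass


@dataclass
class MemoryPlanningLine:
    buffer_name: str


def prune_trailing_planning_lines(lines, output_names):
    # Pop trailing memory-planning lines whose buffers are not among the
    # relevant outputs (this mirrors the check the commit message describes).
    while lines and isinstance(lines[-1], MemoryPlanningLine):
        if lines[-1].buffer_name in output_names:
            break
        lines.pop()
    return lines


# Suppose buf2 is an output of the current partition but not of the
# unpartitioned graph. Checking the whole graph's outputs drops buf2's
# reuse line; checking the partition's own outputs keeps it.
lines = [MemoryPlanningLine("buf0"), MemoryPlanningLine("buf2")]
whole_graph_outputs = {"buf7"}  # outputs of the graph before partitioning
partition_outputs = {"buf2"}    # outputs of the current partition

print(prune_trailing_planning_lines(list(lines), whole_graph_outputs))  # [] -> buf2's reuse plan lost
print(prune_trailing_planning_lines(list(lines), partition_outputs))    # buf2's reuse plan kept
```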

3 files changed: +126 −3 lines

test/inductor/test_cudagraph_trees.py

Lines changed: 119 additions & 0 deletions

```diff
@@ -985,6 +985,125 @@ def f(x):
         num_partitions = get_num_partitions(code)
         self.assertEqual(num_partitions, 2)
 
+    @torch._inductor.config.patch("graph_partition", True)
+    @torch._inductor.config.patch("implicit_fallbacks", True)
+    def test_graph_partition_with_memory_plan_reuse(self):
+        BATCH_SIZE = 16
+        MLP_SIZE = 128
+        HIDDEN_SIZE = 128
+        RANDOM_SEED = 0
+
+        @torch.library.custom_op(
+            "silly::attention",
+            mutates_args=["out"],
+            tags=(torch._C.Tag.cudagraph_unsafe,),
+        )
+        def attention(
+            q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+        ) -> None:
+            out.copy_(q + k + v)
+
+        @attention.register_fake
+        def _(q, k, v, out):
+            return None
+
+        class ParentModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x
+
+        class Attention(torch.nn.Module):
+            def __init__(self, mlp_size: int, hidden_size: int) -> None:
+                super().__init__()
+                self.pre_attn = torch.nn.Linear(mlp_size, hidden_size, bias=False)
+                self.post_attn = torch.nn.Linear(hidden_size, mlp_size, bias=False)
+                self.rms_norm_weight = torch.nn.Parameter(torch.ones(hidden_size))
+
+            def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor:
+                x_f32 = x.float()
+                return (
+                    x_f32
+                    * torch.rsqrt(
+                        torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6
+                    )
+                    * self.rms_norm_weight
+                ).to(x.dtype)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                x = self.pre_attn(x)
+                x = self.rms_norm_ref(x)
+                attn_output = torch.empty_like(x)
+                torch.ops.silly.attention(x, x, x, attn_output)
+                x = attn_output
+                x = self.rms_norm_ref(x)
+                x = self.post_attn(x)
+                return x
+
+        class CompiledAttention(torch.nn.Module):
+            def __init__(
+                self,
+                *,
+                mlp_size: int,
+                hidden_size: int,
+            ) -> None:
+                super().__init__()
+                self.attn = Attention(mlp_size, hidden_size)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.attn(x)
+
+        class CompiledAttentionTwo(CompiledAttention):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.attn(x) + x
+
+        class SimpleModelWithTwoGraphs(ParentModel):
+            def __init__(
+                self,
+                *,
+                mlp_size: int,
+                hidden_size: int,
+            ) -> None:
+                super().__init__()
+                self.attn_one = CompiledAttention(
+                    mlp_size=mlp_size,
+                    hidden_size=hidden_size,
+                )
+                self.attn_two = CompiledAttentionTwo(
+                    mlp_size=mlp_size,
+                    hidden_size=hidden_size,
+                )
+
+                self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda()
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                bsz = x.shape[0]
+                # CUDAGraph expects same tensor addresses for each run
+                self.hidden_states[:bsz].copy_(x)
+                x = self.attn_one(self.hidden_states[:bsz])
+                self.hidden_states[:bsz].copy_(x)
+                x = self.attn_two(self.hidden_states[:bsz])
+                return x
+
+        eager_model = (
+            SimpleModelWithTwoGraphs(
+                mlp_size=MLP_SIZE,
+                hidden_size=HIDDEN_SIZE,
+            )
+            .eval()
+            .cuda()
+        )
+
+        compiled_model = torch.compile(eager_model, mode="reduce-overhead")
+
+        inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda()
+
+        for _ in range(3):
+            eager_out = eager_model(inputs)
+            compiled_out = compiled_model(inputs)
+            self.assertEqual(eager_out, compiled_out)
+
     @torch._inductor.config.patch("graph_partition", True)
     @torch._inductor.config.patch("triton.cudagraph_trees", False)
     def test_graph_partition_gc(self):
```

torch/_inductor/codegen/wrapper.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -1700,7 +1700,8 @@ def memory_plan(self):
         self.lines = MemoryPlanner(self).plan(self.lines)
 
     def memory_plan_reuse(self):
-        out_names = V.graph.get_output_names()
+        outputs = self.get_graph_outputs()
+        out_names = V.graph._get_output_names(outputs)
 
         while (
             self.lines
```

torch/_inductor/graph.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -2410,11 +2410,11 @@ def _compile_to_module_lines(
 
         return mod
 
-    def get_output_names(self) -> list[str]:
+    def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]:
         names = []
         shape_counter = itertools.count(0)
         none_counter = itertools.count(0)
-        for node in self.graph_outputs:
+        for node in graph_outputs:
             if isinstance(node, ir.NoneAsConstantBuffer):
                 names.append(f"{self.name}_none{next(none_counter)}")
             elif isinstance(node, ir.ShapeAsConstantBuffer):
@@ -2423,6 +2423,9 @@ def get_output_names(self) -> list[str]:
                 names.append(node.get_name())
         return names
 
+    def get_output_names(self) -> list[str]:
+        return self._get_output_names(self.graph_outputs)
+
     def is_unspec_arg(self, name: str) -> bool:
         # dynamo wraps unspec variable as 0d CPU tensor,
         # need to convert to scalar during codegen (triton only)
```
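A minimal sketch of the refactoring pattern this `graph.py` change uses: the naming logic is parameterized over an explicit list of output nodes, and the original public method becomes a thin wrapper over the whole-graph outputs, so a partition can reuse the same naming path with its own outputs. The `TinyGraph` class and its node handling below are illustrative stand-ins, not Inductor's `GraphLowering` or `ir` types.

```python
# Illustrative stand-in for the refactor above -- a toy class, not Inductor's
# GraphLowering/ir types. The helper takes an explicit list of output nodes,
# and the public method delegates to it with the whole graph's outputs.
import itertools


class TinyGraph:
    def __init__(self, name, graph_outputs):
        self.name = name
        self.graph_outputs = graph_outputs  # whole-graph outputs

    def _get_output_names(self, graph_outputs):
        # Name whichever outputs we are handed (graph-level or partition-level).
        names = []
        none_counter = itertools.count(0)
        for node in graph_outputs:
            if node is None:  # stand-in for ir.NoneAsConstantBuffer
                names.append(f"{self.name}_none{next(none_counter)}")
            else:
                names.append(node)  # stand-in for node.get_name()
        return names

    def get_output_names(self):
        # Public API unchanged: delegate with the whole graph's outputs.
        return self._get_output_names(self.graph_outputs)


g = TinyGraph("graph_0", ["buf7", None])
print(g.get_output_names())           # ['buf7', 'graph_0_none0']
print(g._get_output_names(["buf2"]))  # a partition's own outputs: ['buf2']
```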
