
Commit e355217

srinathava (Srinath Avadhanula) and a co-author authored
Use all frames of the stack trace when importing (#4075)
We currently use the first frame of the `stack_trace` when importing a node into MLIR. This causes modules with deeply nested ops to lose most of the useful information. This change recovers all the stack frames (at the expected cost of an increase in MLIR size). This also seems to be how we were originally importing from TorchScript.

For an example module like this (in `/tmp/model.py`):

```python
def add_fp32_loader() -> RCPayload:
    class AddFP32Net(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            def bar(x):
                return x + 1.0

            def foo(x1, x2):
                return bar(x1) + bar(x2)

            z1 = foo(inputs["x"], inputs["y"])
            return {"z1": z1}
```

if we import this, we now get:

```mlir
#loc1 = loc("compile.py":1332:0)
module {
  func.func @add_fp32(%arg0: !torch.vtensor<[128,128],f32> loc("compile.py":1332:0)) -> !torch.vtensor<[128,128],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %none = torch.constant.none loc(#loc1)
    %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[128,128],f32>, !torch.none -> !torch.vtensor<[128,128],f32> loc(#loc1)
    %none_0 = torch.constant.none loc(#loc1)
    %1 = torch.aten.clone %arg1, %none_0 : !torch.vtensor<[128,128],f32>, !torch.none -> !torch.vtensor<[128,128],f32> loc(#loc1)
    %float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc10)
    %int1 = torch.constant.int 1 loc(#loc10)
    %2 = torch.aten.add.Scalar %0, %float1.000000e00, %int1 : !torch.vtensor<[128,128],f32>, !torch.float, !torch.int -> !torch.vtensor<[128,128],f32> loc(#loc10)
    %float1.000000e00_1 = torch.constant.float 1.000000e+00 loc(#loc10)
    %int1_2 = torch.constant.int 1 loc(#loc10)
    %3 = torch.aten.add.Scalar %1, %float1.000000e00_1, %int1_2 : !torch.vtensor<[128,128],f32>, !torch.float, !torch.int -> !torch.vtensor<[128,128],f32> loc(#loc10)
    %int1_3 = torch.constant.int 1 loc(#loc9)
    %4 = torch.aten.add.Tensor %2, %3, %int1_3 : !torch.vtensor<[128,128],f32>, !torch.vtensor<[128,128],f32>, !torch.int -> !torch.vtensor<[128,128],f32> loc(#loc9)
    return %4 : !torch.vtensor<[128,128],f32> loc(#loc1)
  } loc(#loc1)
} loc(#loc)
#loc = loc(unknown)
#loc2 = loc("/tmp/model.py":17:0)
#loc3 = loc("/tmp/model.py":20:0)
#loc4 = loc("/tmp/model.py":22:0)
#loc5 = loc("torch/nn/modules/module.py":1562:0)
#loc6 = loc("compile.py":1333:0)
#loc7 = loc(callsite(#loc5 at #loc6))
#loc8 = loc(callsite(#loc4 at #loc7))
#loc9 = loc(callsite(#loc3 at #loc8))
#loc10 = loc(callsite(#loc2 at #loc9))
```

Originally, all ops would have a single location pointing to `compile.py` (the frame from which we initiated the import). Inlining the locations for the final `aten.add.Tensor` op gives us:

```
#loc9:
    "/tmp/model.py":20:0
    "/tmp/model.py":22:0
    "torch/nn/modules/module.py":1562:0
    "compile.py":1333:0
```

---------

Co-authored-by: Srinath Avadhanula <[email protected]>
1 parent fb14944 commit e355217
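
As a rough illustration of the frame recovery described in the commit message (this snippet is not part of the commit), the code below runs the same `re.findall` pattern over a made-up `node.stack_trace` string; the file names, line numbers, and function names are invented to mirror the example module.

```python
# Illustration only: the stack trace text below is fabricated in the format
# that torch.fx records in node.stack_trace.
import re

stack_trace = """\
  File "compile.py", line 1333, in compile_model
    out = module(inputs)
  File "torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/model.py", line 22, in forward
    z1 = foo(inputs["x"], inputs["y"])
  File "/tmp/model.py", line 20, in foo
    return bar(x1) + bar(x2)
  File "/tmp/model.py", line 17, in bar
    return x + 1.0
"""

# The old importer kept only the first match (the outermost frame, compile.py);
# the new code recovers every frame, outermost first and innermost last.
frames = re.findall(r"""File "([^"]+)", line ([0-9]+),""", stack_trace)
for filename, line in frames:
    print(filename, line)
```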

File tree

2 files changed: +40 −4 lines changed


python/torch_mlir/extras/fx_importer.py

Lines changed: 10 additions & 4 deletions
@@ -1174,10 +1174,16 @@ def get_node_location(self, node: torch_fx.Node) -> Optional[Location]:
         # https://github.com/pytorch/pytorch/issues/91000
         stack_trace = node.stack_trace
         if stack_trace:
-            m = re.search(r"""File "([^"]+)", line ([0-9]+),""", stack_trace)
-            if m:
-                filename, line = m.group(1), int(m.group(2))
-                return Location.file(filename, line, col=0, context=self._c)
+            matches = re.findall(r"""File "([^"]+)", line ([0-9]+),""", stack_trace)
+            locations = [
+                Location.file(m[0], int(m[1]), col=0, context=self._c) for m in matches
+            ]
+            if len(locations) > 1:
+                return Location.callsite(
+                    locations[-1], locations[-2::-1], context=self._c
+                )
+            elif len(locations) == 1:
+                return locations[0]
         return Location.unknown(context=self._c)
 
     def set_symbolic_guards(
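
For reference, here is a small standalone sketch (not part of the patch) of how the recovered frames get nested into a single MLIR location by the code above; it assumes the `torch_mlir.ir` Python bindings are importable and reuses made-up file names and line numbers from the commit message's example.

```python
# Standalone sketch of the nesting scheme used in get_node_location above.
# Assumes torch_mlir.ir is importable; the frames themselves are made up.
from torch_mlir.ir import Context, Location

with Context() as ctx:
    # Frames in stack-trace order: outermost call first, innermost last.
    frames = [
        ("compile.py", 1333),
        ("torch/nn/modules/module.py", 1562),
        ("/tmp/model.py", 22),
        ("/tmp/model.py", 20),
        ("/tmp/model.py", 17),
    ]
    locations = [Location.file(f, line, col=0, context=ctx) for f, line in frames]
    # The innermost frame is the callee; its callers follow from nearest to
    # outermost, mirroring locations[-1] and locations[-2::-1] in the diff.
    loc = Location.callsite(locations[-1], locations[-2::-1], context=ctx)
    print(loc)
```

Printing `loc` should show the same nested `callsite(... at ...)` structure that the `#loc7`–`#loc10` aliases express in the example MLIR above.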

test/python/fx_importer/basic_test.py

Lines changed: 30 additions & 0 deletions
@@ -205,3 +205,33 @@ def forward(self):
         "torch-simplification-pipeline",
     )
     print(m)
+
+
+@run
+# CHECK-LABEL: test_stack_trace
+# CHECK: #loc[[LOC1:.+]] = loc(
+# CHECK: #loc[[LOC2:.+]] = loc(
+# CHECK: #loc[[LOC3:.+]] = loc(
+# CHECK: #loc[[LOC4:.+]] = loc(callsite(#loc[[LOC2]] at #loc[[LOC3]]))
+# CHECK: #loc[[LOC5:.+]] = loc(callsite(#loc[[LOC1]] at #loc[[LOC4]]))
+# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC4]])
+def test_stack_trace():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            def bar(x):
+                return x + 1.0
+
+            def foo(x, y):
+                return bar(x) + bar(y)
+
+            z = foo(x, y)
+            return {"z": z}
+
+    x = torch.randn(128, 128)
+    y = torch.randn(128, 128)
+    m = fx.export_and_import(Basic(), x, y, func_name="test_stack_trace")
+    mlir_asm = m.operation.get_asm(enable_debug_info=True)
+    print(mlir_asm)

0 commit comments
