Skip to content

Commit 6b289f2

Browse files
authored
[FxImporter] Added FxImporter test method to be executed via torch.co… (llvm#3795)
1 parent 45bb17e commit 6b289f2

File tree

1 file changed

+78
-2
lines changed

1 file changed

+78
-2
lines changed

projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
import torch.utils._pytree as pytree
99
from torch.export.graph_signature import OutputSpec, OutputKind
1010
from torch.export import ExportedProgram
11+
from torch._dynamo.backends.common import aot_autograd
1112

1213
from torch_mlir import fx
1314
from torch_mlir_e2e_test.configs.utils import (
1415
recursively_convert_to_numpy,
1516
recursively_convert_from_numpy,
1617
)
1718
from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
19+
from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
1820

1921

2022
def refine_result_type(_result):
@@ -31,17 +33,91 @@ def refine_result_type(_result):
3133
class FxImporterTestConfig(TestConfig):
3234
"""TestConfig that runs the torch.nn.Module with Fx Importer"""
3335

34-
def __init__(self, backend, output_type="linalg-on-tensors"):
36+
def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False):
3537
super().__init__()
3638
self._backend = backend
39+
self._torch_compile = torch_compile
3740
self._output_type = output_type
3841

3942
def compile(
4043
self, program: torch.nn.Module, verbose: bool = False
4144
) -> torch.nn.Module:
4245
return program
4346

44-
def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
47+
def run(self, artifact: torch.nn.Module, trace: Trace):
48+
return (
49+
self._export_run(artifact, trace)
50+
if not self._torch_compile
51+
else self._stateless_run(artifact, trace)
52+
)
53+
54+
def _stateless_run(self, artifact: torch.nn.Module, trace: Trace):
55+
dynamic_argument_pos = None
56+
dynamic_dim_pos = None
57+
annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME)
58+
for i, annotation in enumerate(annotations):
59+
if i == 0: # Skip the "self" annotation.
60+
continue
61+
if not annotation[2]:
62+
raise ValueError(
63+
"Can only compile inputs annotated as having value semantics."
64+
)
65+
for dim_i, dim in enumerate(annotation[0]):
66+
if dim == -1:
67+
dynamic_argument_pos = i - 1
68+
dynamic_dim_pos = dim_i
69+
break
70+
if dynamic_argument_pos is not None:
71+
break
72+
result: Trace = []
73+
for item in trace:
74+
75+
def _base_backend(gm: torch.fx.GraphModule, example_inputs):
76+
for node in gm.graph.nodes:
77+
if node.op == "placeholder":
78+
if (
79+
isinstance(node.meta["val"], torch.SymInt)
80+
and not node.users
81+
):
82+
gm.graph.erase_node(node)
83+
module = fx.stateless_fx_import(
84+
gm,
85+
output_type=self._output_type,
86+
model_name=artifact.__class__.__name__,
87+
)
88+
module = self._backend.compile(module)
89+
backend_module = self._backend.load(module)
90+
91+
def invoke_func(*torch_inputs):
92+
torch_inputs = [
93+
x
94+
for x in filter(
95+
lambda i: isinstance(i, torch.Tensor), torch_inputs
96+
)
97+
]
98+
with torch.no_grad():
99+
numpy_inputs = recursively_convert_to_numpy(torch_inputs)
100+
return recursively_convert_from_numpy(
101+
getattr(backend_module, artifact.__class__.__name__)(
102+
*numpy_inputs
103+
)
104+
)
105+
106+
return invoke_func
107+
108+
fw_compiler = aot_autograd(fw_compiler=_base_backend)
109+
if dynamic_argument_pos is not None:
110+
torch._dynamo.mark_dynamic(
111+
item.inputs[dynamic_argument_pos], dynamic_dim_pos
112+
)
113+
module = torch.compile(artifact, backend=fw_compiler)
114+
outputs = module(*item.inputs)
115+
result.append(
116+
TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs)
117+
)
118+
return result
119+
120+
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
45121
result: Trace = []
46122
for item in trace:
47123
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))

0 commit comments

Comments
 (0)