
Commit b657061

zhxchen17 authored and pytorchmergebot committed
[precompile] Integrate AOTI as a backend. (pytorch#167338)
Fixes #ISSUE_NUMBER
Pull Request resolved: pytorch#167338
Approved by: https://github.com/jamesjwu
1 parent 226850c commit b657061

File tree

8 files changed: +225 -11 lines changed

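In short, this PR teaches torch.compile to use AOTInductor as a precompile backend. A minimal end-to-end sketch, distilled from the new tests below (CUDA assumed; the save path is illustrative):

import torch

def fn(x, y):
    return x + y

inputs = (torch.randn(3, 4, device="cuda"), torch.randn(3, 4, device="cuda"))

# "use_aoti" is the option added by this PR; it swaps the Inductor backend
# wrapper for the new AOTInductor one.
compiled_fn = torch.compile(
    fn, fullgraph=True, options={"use_aoti": True}
).aot_compile((inputs, {}))

# The result round-trips through a file like any other precompile artifact.
compiled_fn.save_compiled_function("/tmp/fn.pt")
with open("/tmp/fn.pt", "rb") as f:
    compiled_fn = torch.compiler.load_compiled_function(f)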

test/dynamo/test_aot_compile.py

Lines changed: 93 additions & 2 deletions
@@ -1,9 +1,11 @@
 # Owner(s): ["module: dynamo"]
 
+import copy
 import functools
 import inspect
 import os
 import pickle
+import unittest
 from contextlib import contextmanager
 from unittest.mock import patch
 
@@ -13,13 +15,16 @@
 import torch._inductor.test_case
 import torch.onnx.operators
 import torch.utils.cpp_extension
-from torch._dynamo.aot_compile import ModelInput, SerializableCallable
+from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable
 from torch._dynamo.exc import PackageError, Unsupported
 from torch._dynamo.package import DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.fx._graph_pickler import GraphPickler
-from torch.testing._internal.common_utils import instantiate_parametrized_tests
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    TEST_CUDA,
+)
 
 
 MY_LAMBDA = lambda x: x + 1  # noqa: E731
@@ -599,6 +604,92 @@ def fn(x, y=1):
         actual = compiled_fn(*inputs)
         self.assertEqual(expected, actual)
 
+    @unittest.skipIf(not TEST_CUDA, "requires cuda")
+    def test_aot_compile_with_aoti(self):
+        with torch.device("cuda"):
+            from torch._dynamo.hooks import Hooks
+
+            def fn(x, y):
+                return x + y
+
+            def make_inputs():
+                return (torch.randn(3, 4), torch.randn(3, 4))
+
+            compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph(
+                fn,
+                (make_inputs(), {}),
+                Hooks(),
+                torch._TorchCompileAOTInductorWrapper(None, None, None),
+            )
+
+            test_inputs = make_inputs()
+            expected = fn(*test_inputs)
+            actual = compiled_fn(*test_inputs)
+            self.assertEqual(expected, actual)
+            compiled_fn.save_compiled_function(self.path())
+            with open(self.path(), "rb") as f:
+                compiled_fn = torch.compiler.load_compiled_function(f)
+            actual = compiled_fn(*test_inputs)
+            self.assertEqual(expected, actual)
+
+    @unittest.skipIf(not TEST_CUDA, "requires cuda")
+    def test_aot_compile_with_aoti_module(self):
+        with torch.device("cuda"):
+            from torch._dynamo.hooks import Hooks
+
+            mod = SimpleLinearModule()
+
+            def make_inputs():
+                return (torch.randn(4, 3),)
+
+            compiled_mod = torch._dynamo.aot_compile.aot_compile_module(
+                mod,
+                [ModelInput(make_inputs(), {}, [])],
+                Hooks(),
+                torch._TorchCompileAOTInductorWrapper(None, None, None),
+            )
+
+            def get_grads(m: torch.nn.Module):
+                return {name: p.grad for name, p in m.named_parameters()}
+
+            original_mod = copy.deepcopy(mod)
+            test_inputs = make_inputs()
+            expected = mod(*test_inputs)
+            expected.sum().backward()
+            expected_grads = get_grads(mod)
+
+            actual = compiled_mod(*test_inputs)
+            self.assertEqual(expected, actual)
+            serialized = compiled_mod.serialize()
+            compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized)
+            actual = compiled_fn(*test_inputs)
+            actual.sum().backward()
+            self.assertEqual(get_grads(original_mod), expected_grads)
+
+    @unittest.skipIf(not TEST_CUDA, "requires cuda")
+    def test_aot_compile_with_aoti_torch_compile(self):
+        with torch.device("cuda"):
+
+            def fn(x, y):
+                return x + y
+
+            def make_inputs():
+                return (torch.randn(3, 4), torch.randn(3, 4))
+
+            compiled_fn = torch.compile(
+                fn, fullgraph=True, options={"use_aoti": True}
+            ).aot_compile((make_inputs(), {}))
+            test_inputs = make_inputs()
+            expected = fn(*test_inputs)
+            actual = compiled_fn(*test_inputs)
+            self.assertEqual(expected, actual)
+            compiled_fn.save_compiled_function(self.path())
+            with open(self.path(), "rb") as f:
+                compiled_fn = torch.compiler.load_compiled_function(f)
+            actual = compiled_fn(*test_inputs)
+            self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor")
+            self.assertEqual(expected, actual)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

torch/__init__.py

Lines changed: 35 additions & 1 deletion
@@ -2439,6 +2439,35 @@ def reset(self):
         reset_cudagraph_trees()
 
 
+class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper):
+    compiler_name = "aotinductor"
+
+    def __init__(self, mode, options, dynamic):
+        super().__init__(mode, options, dynamic)
+        self.apply_options({"cpp_wrapper": True})
+        self.apply_options({"aot_inductor.package": True})
+
+    def __call__(self, model_, inputs_):
+        from contextlib import nullcontext
+        from unittest import mock
+
+        from torch._guards import detect_fake_mode
+        from torch._inductor.virtualized import V
+
+        fake_mode = detect_fake_mode(inputs_)
+        ctx = (
+            mock.patch.object(fake_mode, "allow_non_fake_inputs", True)
+            if fake_mode
+            else nullcontext()
+        )
+        with (
+            V.set_aot_compilation(True),
+            ctx,
+            torch._inductor.config.patch("enable_autograd_for_aot", True),
+        ):
+            return super().__call__(model_, inputs_)
+
+
 class _TorchCompileWrapper:
     def __init__(self, backend, mode, options, dynamic):
         from torch._dynamo.backends.registry import lookup_backend
@@ -2672,8 +2701,10 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]:
         backend = bisect_backend
 
     guard_filter_fn = None
+    use_aoti = False
     if options and isinstance(options, dict):
         guard_filter_fn = options.pop("guard_filter_fn", None)
+        use_aoti = options.pop("use_aoti", False)
 
     if torch.compiler.is_exporting():
         warnings.warn(
@@ -2700,7 +2731,10 @@ def export_wrapped_fn(*args, **kwargs):
         return export_wrapped_fn
 
     if backend == "inductor":
-        backend = _TorchCompileInductorWrapper(mode, options, dynamic)
+        if use_aoti:
+            backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic)
+        else:
+            backend = _TorchCompileInductorWrapper(mode, options, dynamic)
     else:
         backend = _TorchCompileWrapper(backend, mode, options, dynamic)
 
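
The wrapper above subclasses _TorchCompileInductorWrapper: it forces cpp_wrapper and aot_inductor.package on, and patches enable_autograd_for_aot for the duration of the compile call. Note that "use_aoti" is popped before the options reach Inductor. A sketch of the dispatch this hunk adds (private API; the argument values are illustrative):

options = {"use_aoti": True}
use_aoti = options.pop("use_aoti", False)  # consumed here, never forwarded to Inductor
wrapper_cls = (
    torch._TorchCompileAOTInductorWrapper
    if use_aoti
    else torch._TorchCompileInductorWrapper
)
backend = wrapper_cls(None, options, None)  # (mode, options, dynamic)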

torch/_dynamo/aot_compile.py

Lines changed: 4 additions & 1 deletion
@@ -53,6 +53,7 @@ class CompileArtifacts:
     argdefs: Optional[tuple[Any, ...]]
     source_info: "SourceInfo"
     device_type: str
+    backend_name: str
     system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)
 
     def check_compatibility(self) -> None:
@@ -166,7 +167,8 @@ def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
         state = pickle.loads(data)
         state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
         deserializer, compiled_fn_state = state["compiled_fn"]
-        state["compiled_fn"] = deserializer(compiled_fn_state)
+        with torch._inductor.config.patch(enable_autograd_for_aot=True):
+            state["compiled_fn"] = deserializer(compiled_fn_state)
         state["original_code"] = SerializedCode.to_code_object(state["original_code"])
 
         artifacts = CompileArtifacts(**state)
@@ -273,6 +275,7 @@ def new_guard_filter_fn(
         argdefs=fn.__defaults__,
         source_info=source_info,
         device_type=device_type,
+        backend_name=getattr(backend, "compiler_name", "unknown"),
     )
     aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
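
CompileArtifacts now records which backend produced the artifact, taken from the wrapper's compiler_name attribute ("aotinductor" for the new wrapper, "unknown" as the fallback). A sketch of checking it after a reload, mirroring the assertion in test_aot_compile_with_aoti_torch_compile (_artifacts is a private field; the path is illustrative):

with open("/tmp/fn.pt", "rb") as f:
    loaded = torch.compiler.load_compiled_function(f)
assert loaded._artifacts.backend_name == "aotinductor"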

torch/_functorch/_aot_autograd/aot_autograd_result.py

Lines changed: 1 addition & 0 deletions
@@ -511,6 +511,7 @@ def wrap_post_compile(
         ).post_compile(
             compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata
         )
+        compiled_fw_func._boxed_call = True
         disable_amp = torch._C._is_any_autocast_enabled()
 
         if needs_autograd:
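
Setting _boxed_call opts the compiled forward into Inductor's boxed calling convention, where the caller passes one list of inputs instead of unpacked positional arguments. A minimal sketch of how a caller is assumed to dispatch on the flag (illustrative; not the exact call site):

def call_compiled(compiled_fn, args):
    if getattr(compiled_fn, "_boxed_call", False):
        return compiled_fn(list(args))  # boxed: a single mutable list argument
    return compiled_fn(*args)  # unboxed: plain positional call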

torch/_inductor/compile_fx.py

Lines changed: 4 additions & 2 deletions
@@ -1640,7 +1640,9 @@ def codegen_and_compile(
                     # pyrefly: ignore [unbound-name]
                     (str, list, torch.fx.GraphModule),
                 ), type(compiled_fn)
-                return CompiledAOTI(compiled_fn)
+                return CompiledAOTI(
+                    filename=compiled_fn, device_type=graph.device_type
+                )
 
             # TODO: Hoist this above V.aot_compilation
             # pyrefly: ignore [unbound-name]
@@ -2713,7 +2715,7 @@ def bw_compiler(
             or torch._guards.TracingContext(fake_mode)
         )
 
-        if V.aot_compilation:
+        if V.aot_compilation and not config.enable_autograd_for_aot:
            from .utils import is_valid_aoti_model_name

            is_valid_aoti_model_name()

torch/_inductor/config.py

Lines changed: 2 additions & 0 deletions
@@ -1190,6 +1190,8 @@ def decide_compile_threads() -> int:
 
 file_lock_timeout: int = int(os.environ.get("TORCHINDUCTOR_FILE_LOCK_TIMEOUT", "600"))
 
+enable_autograd_for_aot: bool = False
+
 
 def get_worker_log_path() -> Optional[str]:
     log_loc = None
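
The flag defaults to off, so existing AOTI paths are unchanged; the AOTInductor wrapper and AOTCompiledFunction.deserialize enable it only for the duration of a compile or load, e.g.:

with torch._inductor.config.patch(enable_autograd_for_aot=True):
    ...  # compile or deserialize here, as in the aot_compile.py hunk above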

torch/_inductor/output_code.py

Lines changed: 80 additions & 5 deletions
@@ -773,20 +773,95 @@ class CompiledAOTI(OutputCode):
     """
 
     filename: Union[str, list[Union[str, Weights]], torch.fx.GraphModule]
+    device_type: str
+    current_callable: Optional[Callable[..., Any]] = None
+    _cached_files: dict[str, bytes] = dataclasses.field(default_factory=dict)
+
+    def __post_init__(self):
+        if not config.aot_inductor.link_libtorch:
+            return
+
+        if (
+            torch._inductor.cpp_builder._IS_MACOS
+            or torch._inductor.cpp_builder._IS_WINDOWS
+        ):
+            return
+
+        if config.aot_inductor.cross_target_platform == "windows":
+            return
+
+        if config.aot_inductor.package_cpp_only:
+            return
+
+        if not config.enable_autograd_for_aot:
+            return
+
+        if isinstance(self.filename, list):
+            current_callable = next(
+                fn for fn in self.filename if isinstance(fn, str) and fn.endswith(".so")
+            )
+        else:
+            current_callable = self.filename
+
+        if isinstance(current_callable, torch.fx.GraphModule):
+            self.current_callable = current_callable
+            return
+
+        if self.device_type.startswith("cuda"):
+            current_callable = (
+                torch._C._aoti.AOTIModelContainerRunnerCuda(  # type: ignore[call-arg]
+                    current_callable,
+                    1,
+                    self.device_type,
+                    "",
+                    True,
+                ).run  # type: ignore[attr-defined]
+            )  # type: ignore[attr-defined]
+        elif self.device_type == "cpu":
+            current_callable = (
+                torch._C._aoti.AOTIModelContainerRunnerCpu(  # type: ignore[call-arg]
+                    current_callable, 1
+                ).run  # type: ignore[attr-defined]
+            )  # type: ignore[attr-defined]
+        else:
+            raise RuntimeError(f"unsupported device type {self.device_type}")
+        self.current_callable = current_callable
+        self._boxed_call = True
+        for file in self._cached_files:
+            if not os.path.exists(file):
+                with open(file, "wb") as f:
+                    f.write(self._cached_files[file])
 
     def __call__(self, inputs: Sequence[Any]) -> Any:
-        raise NotImplementedError("NYI")
+        if self.current_callable is None:
+            raise RuntimeError("AOTInductor compiled so is not loaded")
+        return self.current_callable(inputs)
+
+    def prepare_for_serialization(self) -> None:
+        self.current_callable = None
+        self._cached_files = {}
+        filenames: list[str] = []
+        if isinstance(self.filename, list):
+            filenames = self.filename  # type: ignore[assignment]
+        elif isinstance(self.filename, str):
+            filenames = [self.filename]
+        for name in filenames:
+            with open(name, "rb") as f:
+                self._cached_files[name] = f.read()
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["current_callable"] = None
+        return state
 
     def post_compile(
         self,
         example_inputs: Sequence[InputType],
         constants: CompiledFxGraphConstants,
         graph_kwargs: _CompileFxKwargs,
     ) -> None:
-        pass
-
-    def prepare_for_serialization(self) -> None:
-        pass
+        if self.current_callable is None:
+            self.__post_init__()
 
     def set_triton_bundle(self, triton_bundle: Any) -> None:
         pass
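
Taken together, these methods give CompiledAOTI a serialization story: prepare_for_serialization caches the bytes of the compiled files in _cached_files and drops the live runner, __getstate__ keeps the unpicklable callable out of the pickle, and post_compile re-runs __post_init__ after a load, which rewrites any missing files and reloads the runner. A hedged sketch of the intended round trip (the post_compile arguments here are placeholders):

import pickle

aoti.prepare_for_serialization()  # cache file bytes, drop the runner
restored = pickle.loads(pickle.dumps(aoti))
restored.post_compile([], constants, {})  # reloads the runner via __post_init__
outputs = restored(list(example_inputs))  # boxed call: inputs as a single list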

torch/csrc/inductor/aoti_runner/pybind.cpp

Lines changed: 6 additions & 0 deletions
@@ -66,6 +66,12 @@ void initAOTIRunnerBindings(PyObject* module) {
               int,
               const std::string&,
               const std::string&>())
+      .def(py::init<
+          const std::string&,
+          int,
+          const std::string&,
+          const std::string&,
+          const bool>())
       .def(
           "run",
           &AOTIModelContainerRunnerCuda::run,
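
This adds a five-argument constructor overload to the Python binding, matching the call in output_code.py above. A sketch of instantiating it directly (the path is a placeholder; the meaning of the final bool is inferred from the call site, which passes True):

runner = torch._C._aoti.AOTIModelContainerRunnerCuda(
    "/tmp/model.so",  # compiled AOTI shared library
    1,                # number of model instances
    "cuda",           # device string
    "",               # cubin directory ("" keeps the default)
    True,             # the new boolean flag exposed by this binding
)
outputs = runner.run(inputs)  # inputs: a list of tensors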
