|
1 | 1 | # Owner(s): ["module: dynamo"] |
2 | 2 |
|
| 3 | +import copy |
3 | 4 | import functools |
4 | 5 | import inspect |
5 | 6 | import os |
6 | 7 | import pickle |
| 8 | +import unittest |
7 | 9 | from contextlib import contextmanager |
8 | 10 | from unittest.mock import patch |
9 | 11 |
|
|
13 | 15 | import torch._inductor.test_case |
14 | 16 | import torch.onnx.operators |
15 | 17 | import torch.utils.cpp_extension |
16 | | -from torch._dynamo.aot_compile import ModelInput, SerializableCallable |
| 18 | +from torch._dynamo.aot_compile import AOTCompiledModel, ModelInput, SerializableCallable |
17 | 19 | from torch._dynamo.exc import PackageError, Unsupported |
18 | 20 | from torch._dynamo.package import DynamoCache |
19 | 21 | from torch._dynamo.precompile_context import PrecompileContext |
20 | 22 | from torch._inductor.runtime.runtime_utils import cache_dir |
21 | 23 | from torch.fx._graph_pickler import GraphPickler |
22 | | -from torch.testing._internal.common_utils import instantiate_parametrized_tests |
| 24 | +from torch.testing._internal.common_utils import ( |
| 25 | + instantiate_parametrized_tests, |
| 26 | + TEST_CUDA, |
| 27 | +) |
23 | 28 |
|
24 | 29 |
|
25 | 30 | MY_LAMBDA = lambda x: x + 1 # noqa: E731 |
@@ -599,6 +604,92 @@ def fn(x, y=1): |
599 | 604 | actual = compiled_fn(*inputs) |
600 | 605 | self.assertEqual(expected, actual) |
601 | 606 |
|
| 607 | + @unittest.skipIf(not TEST_CUDA, "requires cuda") |
| 608 | + def test_aot_compile_with_aoti(self): |
| 609 | + with torch.device("cuda"): |
| 610 | + from torch._dynamo.hooks import Hooks |
| 611 | + |
| 612 | + def fn(x, y): |
| 613 | + return x + y |
| 614 | + |
| 615 | + def make_inputs(): |
| 616 | + return (torch.randn(3, 4), torch.randn(3, 4)) |
| 617 | + |
| 618 | + compiled_fn = torch._dynamo.aot_compile.aot_compile_fullgraph( |
| 619 | + fn, |
| 620 | + (make_inputs(), {}), |
| 621 | + Hooks(), |
| 622 | + torch._TorchCompileAOTInductorWrapper(None, None, None), |
| 623 | + ) |
| 624 | + |
| 625 | + test_inputs = make_inputs() |
| 626 | + expected = fn(*test_inputs) |
| 627 | + actual = compiled_fn(*test_inputs) |
| 628 | + self.assertEqual(expected, actual) |
| 629 | + compiled_fn.save_compiled_function(self.path()) |
| 630 | + with open(self.path(), "rb") as f: |
| 631 | + compiled_fn = torch.compiler.load_compiled_function(f) |
| 632 | + actual = compiled_fn(*test_inputs) |
| 633 | + self.assertEqual(expected, actual) |
| 634 | + |
| 635 | + @unittest.skipIf(not TEST_CUDA, "requires cuda") |
| 636 | + def test_aot_compile_with_aoti_module(self): |
| 637 | + with torch.device("cuda"): |
| 638 | + from torch._dynamo.hooks import Hooks |
| 639 | + |
| 640 | + mod = SimpleLinearModule() |
| 641 | + |
| 642 | + def make_inputs(): |
| 643 | + return (torch.randn(4, 3),) |
| 644 | + |
| 645 | + compiled_mod = torch._dynamo.aot_compile.aot_compile_module( |
| 646 | + mod, |
| 647 | + [ModelInput(make_inputs(), {}, [])], |
| 648 | + Hooks(), |
| 649 | + torch._TorchCompileAOTInductorWrapper(None, None, None), |
| 650 | + ) |
| 651 | + |
| 652 | + def get_grads(m: torch.nn.Module): |
| 653 | + return {name: p.grad for name, p in m.named_parameters()} |
| 654 | + |
| 655 | + original_mod = copy.deepcopy(mod) |
| 656 | + test_inputs = make_inputs() |
| 657 | + expected = mod(*test_inputs) |
| 658 | + expected.sum().backward() |
| 659 | + expected_grads = get_grads(mod) |
| 660 | + |
| 661 | + actual = compiled_mod(*test_inputs) |
| 662 | + self.assertEqual(expected, actual) |
| 663 | + serialized = compiled_mod.serialize() |
| 664 | + compiled_fn = AOTCompiledModel.deserialize(original_mod, serialized) |
| 665 | + actual = compiled_fn(*test_inputs) |
| 666 | + actual.sum().backward() |
| 667 | + self.assertEqual(get_grads(original_mod), expected_grads) |
| 668 | + |
| 669 | + @unittest.skipIf(not TEST_CUDA, "requires cuda") |
| 670 | + def test_aot_compile_with_aoti_torch_compile(self): |
| 671 | + with torch.device("cuda"): |
| 672 | + |
| 673 | + def fn(x, y): |
| 674 | + return x + y |
| 675 | + |
| 676 | + def make_inputs(): |
| 677 | + return (torch.randn(3, 4), torch.randn(3, 4)) |
| 678 | + |
| 679 | + compiled_fn = torch.compile( |
| 680 | + fn, fullgraph=True, options={"use_aoti": True} |
| 681 | + ).aot_compile((make_inputs(), {})) |
| 682 | + test_inputs = make_inputs() |
| 683 | + expected = fn(*test_inputs) |
| 684 | + actual = compiled_fn(*test_inputs) |
| 685 | + self.assertEqual(expected, actual) |
| 686 | + compiled_fn.save_compiled_function(self.path()) |
| 687 | + with open(self.path(), "rb") as f: |
| 688 | + compiled_fn = torch.compiler.load_compiled_function(f) |
| 689 | + actual = compiled_fn(*test_inputs) |
| 690 | + self.assertEqual(compiled_fn._artifacts.backend_name, "aotinductor") |
| 691 | + self.assertEqual(expected, actual) |
| 692 | + |
602 | 693 |
|
603 | 694 | if __name__ == "__main__": |
604 | 695 | from torch._dynamo.test_case import run_tests |
|
0 commit comments