 from pathlib import Path
 from typing import NamedTuple
 
+from torch.testing._internal.inductor_utils import HAS_GPU
+
+
+if HAS_GPU:
+    import triton
+    import triton.language as tl
+
+    from torch.library import wrap_triton
+    from torch.utils._triton import has_triton
+
 import torch
 import torch._dynamo as torchdynamo
 import torch._export.serde.schema as schema
 import torch.export._trace
 import torch.utils._pytree as pytree
 from torch._export.db.case import ExportCase, SupportLevel
 from torch._export.db.examples import all_examples
+from torch._export.serde.schema import ArgumentKind
 from torch._export.serde.serialize import (
     _dict_to_dataclass,
     _to_json_bytes,
@@ -582,6 +593,118 @@ def forward(self, x):
             serialized.exported_program.range_constraints[symint.name].max_val, 3
         )
 
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not has_triton(), "requires cuda and triton"
+    )
+    def test_triton_hop(self) -> None:
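+        # Verifies that a wrap_triton kernel call is exported as the
+        # triton_kernel_wrapper_functional HOP, serialized with the expected
+        # arguments, and that deserializing it is still rejected as unimplemented.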
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
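+        # wrap_triton makes the kernel launch traceable by export: under tracing
+        # the call is captured as a higher-order op in the graph rather than
+        # being burned into inductor-generated code.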
+        def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.empty_like(x)
+            n_elements = output.numel()
+
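+            # Launch enough programs to cover n_elements in BLOCK_SIZE chunks.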
+            def grid(meta):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
+
+            return output
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x, y):
+                return custom_add(x, y)
+
+        def custom_add_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.empty_like(x)
+            n_elements = output.numel()
+
+            def grid(meta):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8)
+
+            return output
+
+        class MyModelAutotune(torch.nn.Module):
+            def forward(self, x, y):
+                return custom_add_autotune(x, y)
+
+        device = "cuda"
+
+        for m in [MyModel().to(device), MyModelAutotune().to(device)]:
+            args = (torch.randn(3, device=device), torch.randn(3, device=device))
+            ep = torch.export.export(m, args=args)
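+            # decompose_custom_triton_ops=False appears to keep the triton HOP
+            # intact through decomposition, which the assertions below rely on.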
+            ep = ep.run_decompositions(decompose_custom_triton_ops=False)
+            assert torch.allclose(m(*args), ep.module()(*args))
+
+            serialized = ExportedProgramSerializer().serialize(ep)
+
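+            # Locate the serialized triton HOP node in the graph.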
+            triton_node = None
+            for node in serialized.exported_program.graph_module.graph.nodes:
+                if (
+                    node.target
+                    == "torch.ops.higher_order.triton_kernel_wrapper_functional"
+                ):
+                    triton_node = node
+
+            self.assertIsNotNone(triton_node)
+
+            args = []
+            kwargs = []
+
+            for arg in triton_node.inputs:
+                if arg.kind == ArgumentKind.POSITIONAL:
+                    args.append(arg.arg)
+                elif arg.kind == ArgumentKind.KEYWORD:
+                    kwargs.append(arg.arg)
+
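+            # Expected layout: positional args are the three tensors (x, y,
+            # output) plus n_elements; keyword args carry the kernel name,
+            # grid, output indices, and num_warps.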
+            self.assertEqual(len(args), 4)
+            self.assertEqual(len(kwargs), 4)
+
+            for i in range(3):
+                self.assertIsNotNone(args[i].as_tensor)
+
+            self.assertEqual(args[3].as_int, 3)
+
+            self.assertEqual(kwargs[0].as_string, "add_kernel")  # name
+            self.assertEqual(kwargs[1].as_ints, [1, 1, 1])  # grid
+            self.assertEqual(kwargs[2].as_ints, [2])  # output indices
+            self.assertEqual(
+                kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
+            )  # num_warps
+
+            self.assertEqual(len(triton_node.outputs), 1)
+            self.assertIsNotNone(triton_node.outputs[0].as_tensors)
+            self.assertEqual(
+                len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints)
+            )
+            self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")
+
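+            # Deserializing this HOP is not yet implemented ("nyi"), so the
+            # round trip is expected to fail loudly rather than silently.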
+            with self.assertRaisesRegex(
+                SerializeError,
+                "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional",
+            ):
+                ExportedProgramDeserializer().deserialize(
+                    serialized.exported_program,
+                    serialized.state_dict,
+                    serialized.constants,
+                    serialized.example_inputs,
+                )
+
     def test_kwargs_default(self) -> None:
         """
         Tests that the kwargs default values are serialized even if they are not