
Commit affd071

dolpm authored and pytorchmergebot committed
[export] serialization support for triton_kernel_wrapper_functional (pytorch#161314)
Summary: att

Test Plan: buck2 test mode/opt //caffe2/test:test_export -- test_triton_hop

Rollback Plan:

Differential Revision: D80827767

Pull Request resolved: pytorch#161314
Approved by: https://github.com/angelayi
1 parent dac062f commit affd071
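In brief, the path this commit covers: a triton.jit kernel launched through wrap_triton is captured by torch.export as the triton_kernel_wrapper_functional higher-order op, which the serializer can now handle (deserialization still raises, see below). A minimal sketch adapted from the new test, assuming a CUDA device with triton installed and an add_kernel defined as in that test:

import torch
import triton
from torch.library import wrap_triton
from torch._export.serde.serialize import ExportedProgramSerializer

class M(torch.nn.Module):
    def forward(self, x, y):
        out = torch.empty_like(x)
        n = out.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        # add_kernel is the @triton.jit elementwise-add kernel from the test below
        wrap_triton(add_kernel)[grid](x, y, out, n, 16)
        return out

args = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
ep = torch.export.export(M().to("cuda"), args=args)
# keep the triton higher-order op in the graph instead of decomposing it
ep = ep.run_decompositions(decompose_custom_triton_ops=False)
serialized = ExportedProgramSerializer().serialize(ep)  # supported as of this commit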

2 files changed: +219 −1 lines changed


test/export/test_serialize.py

Lines changed: 123 additions & 0 deletions
@@ -14,13 +14,24 @@
 from pathlib import Path
 from typing import NamedTuple
 
+from torch.testing._internal.inductor_utils import HAS_GPU
+
+
+if HAS_GPU:
+    import triton
+    import triton.language as tl
+
+    from torch.library import wrap_triton
+from torch.utils._triton import has_triton
+
 import torch
 import torch._dynamo as torchdynamo
 import torch._export.serde.schema as schema
 import torch.export._trace
 import torch.utils._pytree as pytree
 from torch._export.db.case import ExportCase, SupportLevel
 from torch._export.db.examples import all_examples
+from torch._export.serde.schema import ArgumentKind
 from torch._export.serde.serialize import (
     _dict_to_dataclass,
     _to_json_bytes,
@@ -582,6 +593,118 @@ def forward(self, x):
             serialized.exported_program.range_constraints[symint.name].max_val, 3
         )
 
+    @unittest.skipIf(
+        not torch.cuda.is_available() or not has_triton(), "requires cuda and triton"
+    )
+    def test_triton_hop(self) -> None:
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.empty_like(x)
+            n_elements = output.numel()
+
+            def grid(meta):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
+
+            return output
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x, y):
+                return custom_add(x, y)
+
+        def custom_add_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.empty_like(x)
+            n_elements = output.numel()
+
+            def grid(meta):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8)
+
+            return output
+
+        class MyModelAutotune(torch.nn.Module):
+            def forward(self, x, y):
+                return custom_add_autotune(x, y)
+
+        device = "cuda"
+
+        for m in [MyModel().to(device), MyModelAutotune().to(device)]:
+            args = (torch.randn(3, device=device), torch.randn(3, device=device))
+            ep = torch.export.export(m, args=args)
+            ep = ep.run_decompositions(decompose_custom_triton_ops=False)
+            assert torch.allclose(m(*args), ep.module()(*args))
+
+            serialized = ExportedProgramSerializer().serialize(ep)
+
+            for node in serialized.exported_program.graph_module.graph.nodes:
+                if (
+                    node.target
+                    == "torch.ops.higher_order.triton_kernel_wrapper_functional"
+                ):
+                    triton_node = node
+
+            self.assertIsNotNone(triton_node)
+
+            args = []
+            kwargs = []
+
+            for arg in triton_node.inputs:
+                if arg.kind == ArgumentKind.POSITIONAL:
+                    args.append(arg.arg)
+                elif arg.kind == ArgumentKind.KEYWORD:
+                    kwargs.append(arg.arg)
+
+            self.assertEqual(len(args), 4)
+            self.assertEqual(len(kwargs), 4)
+
+            for i in range(3):
+                self.assertIsNotNone(args[i].as_tensor)
+
+            self.assertEqual(args[3].as_int, 3)
+
+            self.assertEqual(kwargs[0].as_string, "add_kernel")  # name
+            self.assertEqual(kwargs[1].as_ints, [1, 1, 1])  # grid
+            self.assertEqual(kwargs[2].as_ints, [2])  # output indices
+            self.assertEqual(
+                kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
+            )  # num warps
+
+            self.assertEqual(len(triton_node.outputs), 1)
+            self.assertIsNotNone(triton_node.outputs[0].as_tensors)
+            self.assertEqual(
+                len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints)
+            )
+            self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")
+
+        with self.assertRaisesRegex(
+            SerializeError,
+            "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional",
+        ):
+            ExportedProgramDeserializer().deserialize(
+                serialized.exported_program,
+                serialized.state_dict,
+                serialized.constants,
+                serialized.example_inputs,
+            )
+
     def test_kwargs_default(self) -> None:
         """
         Tests that the kwargs default values are serialized even if they are not
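The asserts above pin down the wire layout: runtime kernel arguments travel as POSITIONAL inputs (three tensors plus n_elements) and kernel metadata as KEYWORD inputs (name, grid, output_indices, num_warps). A compact sketch of splitting a serialized node's inputs back apart, assuming a triton_node found the way the test finds it:

from torch._export.serde.schema import ArgumentKind

pos = [a.arg for a in triton_node.inputs if a.kind == ArgumentKind.POSITIONAL]
kw = [a.arg for a in triton_node.inputs if a.kind == ArgumentKind.KEYWORD]
# pos -> [x, y, output tensors, n_elements == 3]
# kw  -> [name "add_kernel", grid [1, 1, 1], output_indices [2], num_warps 4 or 8]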

torch/_export/serde/serialize.py

Lines changed: 96 additions & 1 deletion
@@ -35,6 +35,7 @@
 from torch.utils._sympy.symbol import prefix_str, SymT
 from torch.utils._sympy.value_ranges import ValueRanges
 from torch.utils._traceback import CapturedTraceback
+from torch.utils._triton import has_triton
 
 from ..utils import remove_proxy_from_state_dict
 from .schema import (  # type: ignore[attr-defined]
@@ -93,6 +94,14 @@
 from .union import _Union
 
 
+if has_triton():
+    from triton.runtime.autotuner import Autotuner
+else:
+
+    class Autotuner:  # type: ignore[no-redef]
+        pass
+
+
 __all__ = [
     "serialize",
     "GraphModuleSerializer",
@@ -670,6 +679,75 @@ def serialize_tensor_list_output(node):
                 metadata=self.serialize_metadata(node),
                 is_hop_single_tensor_return=False,
             )
+        elif (
+            node.target
+            is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
+        ):
+            assert has_triton(), "triton required to serialize triton kernels"
+
+            meta_val = node.meta["val"]
+            assert isinstance(meta_val, dict)
+
+            output_keys = meta_val.keys()
+            output_indices = []
+
+            assert isinstance(node.kwargs["kernel_idx"], int)
+            kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel(
+                node.kwargs["kernel_idx"]
+            )
+
+            if isinstance(kernel, Autotuner):
+                assert len(kernel.configs) == 1
+                num_warps = kernel.configs[0].num_warps
+                assert kernel.configs[0].num_ctas == 1, (
+                    "serialization only supports num_ctas == 1"
+                )
+                kernel = kernel.fn
+            else:
+                num_warps = 4
+
+            constexpr_keys = set()
+            for p in kernel.params:
+                if p.is_constexpr:
+                    constexpr_keys.add(p.name)
+
+            found_constexpr = False
+            args_new = ()
+            i = 0
+
+            assert isinstance(node.kwargs["kwargs"], dict)
+            for k, v in node.kwargs["kwargs"].items():
+                # don't serialize constexpr since they will
+                # be embedded into the binary and don't
+                # need to be passed around as attributes
+                if k in constexpr_keys:
+                    found_constexpr = True
+                    continue
+
+                assert not found_constexpr, (
+                    "non-constexpr args found after constexpr arg(s)"
+                )
+
+                if k in output_keys:
+                    output_indices.append(i)
+                args_new += (v,)  # type: ignore[assignment]
+                i += 1
+
+            assert isinstance(node.kwargs["grid"], list)
+            kwargs_new = {
+                "name": kernel.fn.__name__,
+                "grid": node.kwargs["grid"][0],
+                "output_indices": output_indices,
+                "num_warps": num_warps,
+            }
+
+            ex_node = Node(
+                target=self.serialize_operator(node.target),
+                inputs=self.serialize_hoo_inputs(args_new, kwargs_new),
+                outputs=self.serialize_hoo_outputs(node),
+                metadata=self.serialize_metadata(node),
+                is_hop_single_tensor_return=_is_hop_single_tensor_return(node),
+            )
         else:
             ex_node = Node(
                 target=self.serialize_operator(node.target),
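Traced through this branch for the test's launch wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16), a worked illustration of what gets serialized: BLOCK_SIZE=16 is a tl.constexpr parameter, so it is baked into the compiled kernel and filtered out, leaving four runtime arguments; output sits at index 2; a plain (non-Autotuner) kernel falls back to num_warps=4. The values below match the asserts in test_triton_hop:

import torch

x, y = torch.randn(3), torch.randn(3)
output = torch.empty_like(x)
n_elements = output.numel()  # 3

args_new = (x, y, output, n_elements)  # BLOCK_SIZE=16 dropped (constexpr)
kwargs_new = {
    "name": "add_kernel",    # kernel.fn.__name__
    "grid": [1, 1, 1],       # node.kwargs["grid"][0]
    "output_indices": [2],   # `output` is at index 2 of args_new
    "num_warps": 4,          # default when the kernel is not an Autotuner
}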
@@ -1541,6 +1619,17 @@ def serialize_hoo_outputs(self, node: torch.fx.Node) -> list[Argument]:
                 outputs.append(self.serialize_output(name, element_meta_val))
 
             return outputs
+        elif isinstance(meta_val, dict):
+            tensor_args = []
+            # use the dict key as the idx
+            for idx, meta in meta_val.items():
+                if not isinstance(meta, torch.Tensor):
+                    raise SerializeError(
+                        f"Serialize list output with type {type(meta)} nyi"
+                    )
+                name = self._output_node_name_at_index(node, idx)
+                tensor_args.append(self.serialize_tensor_output(name, meta))
+            return [Argument.create(as_tensors=tensor_args)]
         else:
             return [self.serialize_output(node.name, meta_val)]

@@ -2067,7 +2156,13 @@ def _is_single_tensor_return(target) -> bool:
 
             fx_node = self.graph.create_node("call_function", target, args, {}, name)
             self.deserialize_sym_op_outputs(serialized_node, fx_node)
-
+        elif (
+            target
+            is torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional
+        ):
+            raise SerializeError(
+                "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional"
+            )
         elif isinstance(target, torch._ops.HigherOrderOperator):
             args, kwargs = self.deserialize_hoo_inputs(serialized_node.inputs)
             metadata = self.deserialize_metadata(serialized_node.metadata)
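Deserialization of the op is deliberately left unimplemented here; round-tripping raises SerializeError, which the tail of the new test checks. A minimal sketch, reusing serialized from the sketch at the top:

from torch._export.serde.serialize import (
    ExportedProgramDeserializer,
    SerializeError,
)

try:
    ExportedProgramDeserializer().deserialize(
        serialized.exported_program,
        serialized.state_dict,
        serialized.constants,
        serialized.example_inputs,
    )
except SerializeError as e:
    print(e)  # deserialize nyi for ...triton_kernel_wrapper_functional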
