Commit e85b3ee

feat: _register_custom_op supports List[torch.Tensor] (#2529)
1 parent 0d6804c commit e85b3ee

File tree: 3 files changed, +200 −4 lines
Lines changed: 188 additions & 0 deletions

@@ -0,0 +1,188 @@
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import numpy as np
import pytest
import torch
import torch.nn as nn
from torch._library.custom_ops import CustomOpDef

import thunder
from thunder.core import dtypes
from thunder.core import devices
from thunder.torch.custom_op import _register_custom_op
from thunder.executors.custom_op_ex import custom_op_ex
from thunder.tests.framework import TorchExecutor
from thunder.tests.framework import instantiate

if TYPE_CHECKING:
    from thunder.core.symbol import BoundSymbol


@torch.library.custom_op("my_custom_op::list_mul", mutates_args=())
def list_mul(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]:
    if len(tensors) != 2:
        raise ValueError("The list of tensors must contain exactly two elements for this operation.")
    return [tensors[0] * tensors[1]]


@torch.library.register_kernel("my_custom_op::list_mul", "cpu")
def _(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]:
    return [
        torch.from_numpy(
            np.multiply(
                tensors[0].numpy(force=True),
                tensors[1].numpy(force=True),
            )
        )
    ]


@torch.library.register_kernel("my_custom_op::list_mul", "cuda")
def _(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]:
    return [tensors[0] * tensors[1]]


@torch.library.register_fake("my_custom_op::list_mul")
def _(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]:
    return [torch.empty_like(tensors[0])]


def setup_context_for_my_custom_op_list_mul(ctx, inputs, output) -> None:
    tensors_list, *_ = inputs
    ctx.save_for_backward(tensors_list[0], tensors_list[1])


def backward_of_my_custom_op_list_mul(ctx, grad) -> tuple[list[torch.Tensor], None, None]:
    a, b = ctx.saved_tensors
    return [torch.ops.my_custom_op.list_mul([grad, b]), torch.ops.my_custom_op.list_mul([grad, a])], None, None


torch.library.register_autograd(
    "my_custom_op::list_mul",
    backward_of_my_custom_op_list_mul,
    setup_context=setup_context_for_my_custom_op_list_mul,
)


has_triton_op = torch.cuda.is_available() and package_available("triton")
if has_triton_op:
    import triton
    import triton.language as tl

    DEVICE = triton.runtime.driver.active.get_active_torch_device()

    @triton.jit
    def list_mul_triton_kernel(
        x_ptr,  # *Pointer* to first input vector.
        y_ptr,  # *Pointer* to second input vector.
        output_ptr,  # *Pointer* to output vector.
        n_elements,  # Size of the vector.
        BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
        # NOTE: `constexpr` so it can be used as a shape value.
    ):
        # There are multiple 'programs' processing different data. We identify which program
        # we are here:
        pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
        # This program will process inputs that are offset from the initial data.
        # For instance, if you had a vector of length 256 and block_size of 64, the programs
        # would each access the elements [0:64, 64:128, 128:192, 192:256].
        # Note that offsets is a list of pointers:
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # Create a mask to guard memory operations against out-of-bounds accesses.
        mask = offsets < n_elements
        # Load x and y from DRAM, masking out any extra elements in case the input is not a
        # multiple of the block size.
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x * y
        # Write the result back to DRAM.
        tl.store(output_ptr + offsets, output, mask=mask)

    @torch.library.triton_op("my_triton_op::list_mul", mutates_args=())
    def list_mul_triton(tensors: list[torch.Tensor]) -> list[torch.Tensor]:
        if len(tensors) != 2:
            raise ValueError("The list of tensors must contain exactly two elements for this operation.")
        x = tensors[0]
        y = tensors[1]
        output = torch.empty_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        torch.library.wrap_triton(list_mul_triton_kernel)[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
        return [output]

    torch.library.register_autograd(
        "my_triton_op::list_mul",
        backward_of_my_custom_op_list_mul,
        setup_context=setup_context_for_my_custom_op_list_mul,
    )


def _run_test(module_cls, custom_op: CustomOpDef, device: torch.device, dtype: torch.dtype):
    SHAPE = (8, 2)
    _symbol = _register_custom_op(custom_op)

    module = module_cls().to(device=device, dtype=dtype)
    jitted = thunder.jit(module, executors=[custom_op_ex])
    ref = module_cls().to(device=device, dtype=dtype)
    ref.load_state_dict(module.state_dict())

    x = torch.testing.make_tensor(SHAPE, device=device, dtype=dtype)
    y = torch.testing.make_tensor(SHAPE, device=device, dtype=dtype)
    inputs_list = [x, y]
    inputs_list_ref = [x.clone().detach() for x in inputs_list]

    ref_out = ref(inputs_list_ref)
    out = jitted(inputs_list)
    torch.testing.assert_close(ref_out, out)
    out.mean().backward()

    fwd_extrace = thunder.last_traces(jitted)[-1]
    bsym: BoundSymbol
    custom_ex_bsym_found: bool = False
    for bsym in fwd_extrace.bound_symbols:
        if bsym.sym.name == _symbol.name and bsym.sym.executor is custom_op_ex:
            custom_ex_bsym_found = True
    assert custom_ex_bsym_found


@instantiate(
    executors=(TorchExecutor,),
    devicetypes=(devices.DeviceType.CPU, devices.DeviceType.CUDA),
    dtypes=(dtypes.float32,),
)
def test_torch_library_custom_op(_, device: str, dtype: dtypes.dtype):
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2, bias=False)

        def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor:
            h = torch.ops.my_custom_op.list_mul(tensors)
            activation = torch.relu(h[0])
            out = self.linear(activation)
            return out

    _run_test(MyModule, list_mul, devices.to_torch_device(device), dtypes.to_torch_dtype(dtype))


@pytest.mark.skipif(not has_triton_op, reason="triton is not available")
@instantiate(
    executors=(TorchExecutor,),
    devicetypes=(devices.DeviceType.CUDA,),
    dtypes=(dtypes.float32,),
)
def test_torch_library_triton_op(_, device: str, dtype: dtypes.dtype):
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 2, bias=False)

        def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor:
            h = torch.ops.my_triton_op.list_mul(tensors)
            activation = torch.relu(h[0])
            out = self.linear(activation)
            return out

    _run_test(MyModule, list_mul_triton, devices.to_torch_device(device), dtypes.to_torch_dtype(dtype))
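
For orientation only, here is a minimal end-to-end sketch (not part of the commit) of driving the list-taking op above through Thunder without the test harness. It assumes the definitions from the test file have already been executed in the current process; the function f and the tensors a and b are illustrative names.

import torch
import thunder
from thunder.executors.custom_op_ex import custom_op_ex
from thunder.torch.custom_op import _register_custom_op

# Register the CustomOpDef defined above so Thunder can hand it to custom_op_ex.
_register_custom_op(list_mul)


def f(tensors: list[torch.Tensor]) -> torch.Tensor:
    # The op returns a single-element list; reduce it to a scalar for backward().
    return torch.ops.my_custom_op.list_mul(tensors)[0].sum()


jf = thunder.jit(f, executors=[custom_op_ex])
a = torch.randn(8, 2, requires_grad=True)
b = torch.randn(8, 2)
out = jf([a, b])
torch.testing.assert_close(out, (a * b).sum())
out.backward()  # expected to route through the autograd rule registered above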

thunder/torch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -6697,7 +6697,7 @@ def _get_fake_arg(inp: Any):
             raise NotImplementedError("Unsupported for NumberProxy.value=None")
         else:
             return inp.value
-    elif isinstance(inp, TensorProxy):
+    elif isinstance(inp, (TensorProxy, torch.Tensor)):
         return torch.empty(
             inp.shape,
             dtype=to_torch_dtype(inp.dtype),
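
The one-line widening above means _get_fake_arg now builds an empty placeholder for a concrete torch.Tensor just as it does for a TensorProxy, which matters once Tensor[] arguments put real tensors on this path. A rough sketch of that idea follows, with a hypothetical _fake_like helper and the meta device chosen purely for illustration.

import torch

def _fake_like(inp) -> torch.Tensor:
    # Mirror of the torch.empty(...) call in the hunk above: only shape and
    # dtype metadata are needed, no real storage. (The real code converts
    # proxy dtypes via to_torch_dtype; plain torch.Tensor dtypes pass through.)
    return torch.empty(inp.shape, dtype=inp.dtype, device="meta")

x = torch.randn(8, 2)
print(_fake_like(x).shape, _fake_like(x).device)  # torch.Size([8, 2]) meta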

thunder/torch/custom_op.py

Lines changed: 11 additions & 3 deletions

@@ -4,7 +4,7 @@
 import ast
 import inspect
 
-from torch import TensorType
+from torch import TensorType, ListType
 
 from thunder.core import baseutils
 from thunder.core.symbol import Symbol
@@ -332,12 +332,20 @@ def _register_custom_op(custom_op: CustomOpDef) -> Symbol:
 
     schema: FunctionSchema = torch_opoverload._schema
     schema_arguments: list[Argument] = schema.arguments
-    tensor_indices: tuple[int] = tuple(i for i, arg in enumerate(schema_arguments) if isinstance(arg.type, TensorType))
+    tensor_indices: tuple[int] = tuple(
+        i
+        for i, arg in enumerate(schema_arguments)
+        if (isinstance(arg.type, TensorType) or arg.type.isSubtypeOf(ListType.ofTensors()))
+    )
     tensor_arity: int = len(tensor_indices)
     baseutils.check(tensor_arity > 0, lambda: f"arity of {custom_op._qualname} should be greater than 0: {schema}")
     schema_returns: list[Argument] = schema.returns
     return_arity: int = len(schema_returns)
-    tensor_return_arity: int = len(list(filter(lambda a: isinstance(a.type, TensorType), schema_returns)))
+    tensor_return_arity: int = len(
+        list(
+            filter(lambda a: isinstance(a.type, TensorType) or a.type.isSubtypeOf(ListType.ofTensors()), schema_returns)
+        )
+    )
     baseutils.check(return_arity == tensor_return_arity, lambda: f"Return values include non-Tensor values: {schema}")
 
     has_autograd_def = _has_autograd_def(custom_op)
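
To see what the widened schema filter accepts, the snippet below (illustrative only, and assuming the my_custom_op::list_mul op from the test file has been registered in the current process) reproduces the argument check on that op's schema.

import torch
from torch import ListType, TensorType

schema = torch.ops.my_custom_op.list_mul.default._schema
tensorish_args = tuple(
    i
    for i, arg in enumerate(schema.arguments)
    if isinstance(arg.type, TensorType) or arg.type.isSubtypeOf(ListType.ofTensors())
)
print(schema)          # roughly: my_custom_op::list_mul(Tensor[] tensors, float? c=None, str d="") -> Tensor[]
print(tensorish_args)  # (0,) -- only the Tensor[] argument is tensor-like; c and d are skipped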
