|
| 1 | +from typing import TYPE_CHECKING |
| 2 | + |
| 3 | +from lightning_utilities.core.imports import package_available |
| 4 | +import numpy as np |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | +import torch.nn as nn |
| 8 | +from torch._library.custom_ops import CustomOpDef |
| 9 | + |
| 10 | +import thunder |
| 11 | +from thunder.core import dtypes |
| 12 | +from thunder.core import devices |
| 13 | +from thunder.torch.custom_op import _register_custom_op |
| 14 | +from thunder.executors.custom_op_ex import custom_op_ex |
| 15 | +from thunder.tests.framework import TorchExecutor |
| 16 | +from thunder.tests.framework import instantiate |
| 17 | + |
| 18 | +if TYPE_CHECKING: |
| 19 | + from thunder.core.symbol import BoundSymbol |
| 20 | + |
| 21 | + |
| 22 | +@torch.library.custom_op("my_custom_op::list_mul", mutates_args=()) |
| 23 | +def list_mul(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]: |
| 24 | + if len(tensors) != 2: |
| 25 | + raise ValueError("The list of tensors must contain exactly two elements for this operation.") |
| 26 | + return [tensors[0] * tensors[1]] |
| 27 | + |
| 28 | + |
| 29 | +@torch.library.register_kernel("my_custom_op::list_mul", "cpu") |
| 30 | +def _(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]: |
| 31 | + return [ |
| 32 | + torch.from_numpy( |
| 33 | + np.multiply( |
| 34 | + tensors[0].numpy(force=True), |
| 35 | + tensors[1].numpy(force=True), |
| 36 | + ) |
| 37 | + ) |
| 38 | + ] |
| 39 | + |
| 40 | + |
| 41 | +@torch.library.register_kernel("my_custom_op::list_mul", "cuda") |
| 42 | +def _(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]: |
| 43 | + return [tensors[0] * tensors[1]] |
| 44 | + |
| 45 | + |
| 46 | +@torch.library.register_fake("my_custom_op::list_mul") |
| 47 | +def _(tensors: list[torch.Tensor], c: float | None = None, d: str = "") -> list[torch.Tensor]: |
| 48 | + return [torch.empty_like(tensors[0])] |
| 49 | + |
| 50 | + |
| 51 | +def setup_context_for_my_custom_op_list_mul(ctx, inputs, output) -> None: |
| 52 | + tensors_list, *_ = inputs |
| 53 | + ctx.save_for_backward(tensors_list[0], tensors_list[1]) |
| 54 | + |
| 55 | + |
| 56 | +def backward_of_my_custom_op_list_mul(ctx, grad) -> tuple[list[torch.Tensor], None, None]: |
| 57 | + a, b = ctx.saved_tensors |
| 58 | + return [torch.ops.my_custom_op.list_mul([grad, b]), torch.ops.my_custom_op.list_mul([grad, a])], None, None |
| 59 | + |
| 60 | + |
| 61 | +torch.library.register_autograd( |
| 62 | + "my_custom_op::list_mul", |
| 63 | + backward_of_my_custom_op_list_mul, |
| 64 | + setup_context=setup_context_for_my_custom_op_list_mul, |
| 65 | +) |
| 66 | + |
| 67 | + |
| 68 | +has_triton_op = torch.cuda.is_available() and package_available("triton") |
| 69 | +if has_triton_op: |
| 70 | + import triton |
| 71 | + import triton.language as tl |
| 72 | + |
| 73 | + DEVICE = triton.runtime.driver.active.get_active_torch_device() |
| 74 | + |
| 75 | + @triton.jit |
| 76 | + def list_mul_triton_kernel( |
| 77 | + x_ptr, # *Pointer* to first input vector. |
| 78 | + y_ptr, # *Pointer* to second input vector. |
| 79 | + output_ptr, # *Pointer* to output vector. |
| 80 | + n_elements, # Size of the vector. |
| 81 | + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. |
| 82 | + # NOTE: `constexpr` so it can be used as a shape value. |
| 83 | + ): |
| 84 | + # There are multiple 'programs' processing different data. We identify which program |
| 85 | + # we are here: |
| 86 | + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. |
| 87 | + # This program will process inputs that are offset from the initial data. |
| 88 | + # For instance, if you had a vector of length 256 and block_size of 64, the programs |
| 89 | + # would each access the elements [0:64, 64:128, 128:192, 192:256]. |
| 90 | + # Note that offsets is a list of pointers: |
| 91 | + block_start = pid * BLOCK_SIZE |
| 92 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 93 | + # Create a mask to guard memory operations against out-of-bounds accesses. |
| 94 | + mask = offsets < n_elements |
| 95 | + # Load x and y from DRAM, masking out any extra elements in case the input is not a |
| 96 | + # multiple of the block size. |
| 97 | + x = tl.load(x_ptr + offsets, mask=mask) |
| 98 | + y = tl.load(y_ptr + offsets, mask=mask) |
| 99 | + output = x * y |
| 100 | + # Write x + y back to DRAM. |
| 101 | + tl.store(output_ptr + offsets, output, mask=mask) |
| 102 | + |
| 103 | + @torch.library.triton_op("my_triton_op::list_mul", mutates_args=()) |
| 104 | + def list_mul_triton(tensors: list[torch.Tensor]) -> list[torch.Tensor]: |
| 105 | + if len(tensors) != 2: |
| 106 | + raise ValueError("The list of tensors must contain exactly two elements for this operation.") |
| 107 | + x = tensors[0] |
| 108 | + y = tensors[1] |
| 109 | + output = torch.empty_like(x) |
| 110 | + n_elements = output.numel() |
| 111 | + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| 112 | + torch.library.wrap_triton(list_mul_triton_kernel)[grid](x, y, output, n_elements, BLOCK_SIZE=1024) |
| 113 | + return [output] |
| 114 | + |
| 115 | + torch.library.register_autograd( |
| 116 | + "my_triton_op::list_mul", |
| 117 | + backward_of_my_custom_op_list_mul, |
| 118 | + setup_context=setup_context_for_my_custom_op_list_mul, |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +def _run_test(module_cls, custom_op: CustomOpDef, device: torch.device, dtype: torch.dtype): |
| 123 | + SHAPE = (8, 2) |
| 124 | + _symbol = _register_custom_op(custom_op) |
| 125 | + |
| 126 | + module = module_cls().to(device=device, dtype=dtype) |
| 127 | + jitted = thunder.jit(module, executors=[custom_op_ex]) |
| 128 | + ref = module_cls().to(device=device, dtype=dtype) |
| 129 | + ref.load_state_dict(module.state_dict()) |
| 130 | + |
| 131 | + x = torch.testing.make_tensor(SHAPE, device=device, dtype=dtype) |
| 132 | + y = torch.testing.make_tensor(SHAPE, device=device, dtype=dtype) |
| 133 | + inputs_list = [x, y] |
| 134 | + inputs_list_ref = [x.clone().detach() for x in inputs_list] |
| 135 | + |
| 136 | + ref_out = ref(inputs_list_ref) |
| 137 | + out = jitted(inputs_list) |
| 138 | + torch.testing.assert_close(ref_out, out) |
| 139 | + out.mean().backward() |
| 140 | + |
| 141 | + fwd_extrace = thunder.last_traces(jitted)[-1] |
| 142 | + bsym: BoundSymbol |
| 143 | + custom_ex_bsym_found: bool = False |
| 144 | + for bsym in fwd_extrace.bound_symbols: |
| 145 | + if bsym.sym.name == _symbol.name and bsym.sym.executor is custom_op_ex: |
| 146 | + custom_ex_bsym_found = True |
| 147 | + assert custom_ex_bsym_found |
| 148 | + |
| 149 | + |
| 150 | +@instantiate( |
| 151 | + executors=(TorchExecutor,), |
| 152 | + devicetypes=(devices.DeviceType.CPU, devices.DeviceType.CUDA), |
| 153 | + dtypes=(dtypes.float32,), |
| 154 | +) |
| 155 | +def test_torch_library_custom_op(_, device: str, dtype: dtypes.dtype): |
| 156 | + class MyModule(nn.Module): |
| 157 | + def __init__(self): |
| 158 | + super().__init__() |
| 159 | + self.linear = nn.Linear(2, 2, bias=False) |
| 160 | + |
| 161 | + def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor: |
| 162 | + h = torch.ops.my_custom_op.list_mul(tensors) |
| 163 | + activation = torch.relu(h[0]) |
| 164 | + out = self.linear(activation) |
| 165 | + return out |
| 166 | + |
| 167 | + _run_test(MyModule, list_mul, devices.to_torch_device(device), dtypes.to_torch_dtype(dtype)) |
| 168 | + |
| 169 | + |
| 170 | +@pytest.mark.skipif(not has_triton_op, reason="triton is not available") |
| 171 | +@instantiate( |
| 172 | + executors=(TorchExecutor,), |
| 173 | + devicetypes=(devices.DeviceType.CUDA,), |
| 174 | + dtypes=(dtypes.float32,), |
| 175 | +) |
| 176 | +def test_torch_library_triton_op(_, device: str, dtype: dtypes.dtype): |
| 177 | + class MyModule(nn.Module): |
| 178 | + def __init__(self): |
| 179 | + super().__init__() |
| 180 | + self.linear = nn.Linear(2, 2, bias=False) |
| 181 | + |
| 182 | + def forward(self, tensors: list[torch.Tensor]) -> torch.Tensor: |
| 183 | + h = torch.ops.my_triton_op.list_mul(tensors) |
| 184 | + activation = torch.relu(h[0]) |
| 185 | + out = self.linear(activation) |
| 186 | + return out |
| 187 | + |
| 188 | + _run_test(MyModule, list_mul_triton, devices.to_torch_device(device), dtypes.to_torch_dtype(dtype)) |
0 commit comments