|
| 1 | +import torch |
1 | 2 | import triton
|
2 | 3 | import triton.language as tl
|
3 | 4 |
|
@@ -36,3 +37,90 @@ def triton_exp_kernel(
|
36 | 37 | if profile_mem is not None:
|
37 | 38 | end = time()
|
38 | 39 | tl.store(profile_mem + pid, end - start)
|
| 40 | + |
| 41 | + |
| 42 | +@triton.jit |
| 43 | +def triton_exp_backward_kernel( |
| 44 | + grad_output_ptr, # *Pointer* to grad_output vector. |
| 45 | + output_ptr, # *Pointer* to forward pass output vector (exp(x)). |
| 46 | + grad_input_ptr, # *Pointer* to grad_input vector. |
| 47 | + n_elements, # Size of the vector. |
| 48 | + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. |
| 49 | + profile_mem=None, # *Pointer* to profile_mem. |
| 50 | +): |
| 51 | + if profile_mem is not None: |
| 52 | + start = time() |
| 53 | + |
| 54 | + # There are multiple 'programs' processing different data. We identify which program |
| 55 | + # we are here: |
| 56 | + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. |
| 57 | + |
| 58 | + # This program will process inputs that are offset from the initial data. |
| 59 | + block_start = pid * BLOCK_SIZE |
| 60 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 61 | + # Create a mask to guard memory operations against out-of-bounds accesses. |
| 62 | + mask = offsets < n_elements |
| 63 | + |
| 64 | + # Load grad_output and output from DRAM |
| 65 | + grad_output = tl.load(grad_output_ptr + offsets, mask=mask) |
| 66 | + output = tl.load(output_ptr + offsets, mask=mask) |
| 67 | + |
| 68 | + # Compute grad_input = grad_output * output (since d/dx(exp(x)) = exp(x)) |
| 69 | + grad_input = grad_output * output |
| 70 | + |
| 71 | + # Write grad_input back to DRAM. |
| 72 | + tl.store(grad_input_ptr + offsets, grad_input, mask=mask) |
| 73 | + |
| 74 | + if profile_mem is not None: |
| 75 | + end = time() |
| 76 | + tl.store(profile_mem + pid, end - start) |
| 77 | + |
| 78 | + |
| 79 | +class TritonExpFunction(torch.autograd.Function): |
| 80 | + @staticmethod |
| 81 | + def forward( |
| 82 | + ctx, x: torch.Tensor, block_size: int = 1024, profile_mem: torch.Tensor = None |
| 83 | + ): |
| 84 | + # Allocate output tensor |
| 85 | + output = torch.empty_like(x) |
| 86 | + n_elements = output.numel() |
| 87 | + |
| 88 | + # Launch grid - number of blocks needed |
| 89 | + grid = lambda meta: (triton.cdiv(n_elements, block_size),) |
| 90 | + |
| 91 | + # Launch forward kernel |
| 92 | + triton_exp_kernel[grid]( |
| 93 | + x, output, n_elements, BLOCK_SIZE=block_size, profile_mem=profile_mem |
| 94 | + ) |
| 95 | + |
| 96 | + # Save output for backward pass |
| 97 | + ctx.save_for_backward(output) |
| 98 | + ctx.block_size = block_size |
| 99 | + ctx.profile_mem = profile_mem |
| 100 | + |
| 101 | + return output |
| 102 | + |
| 103 | + @staticmethod |
| 104 | + def backward(ctx, grad_output: torch.Tensor): |
| 105 | + # Retrieve saved tensors |
| 106 | + (output,) = ctx.saved_tensors |
| 107 | + |
| 108 | + # Allocate grad_input tensor |
| 109 | + grad_input = torch.empty_like(grad_output) |
| 110 | + n_elements = grad_output.numel() |
| 111 | + |
| 112 | + # Launch grid - number of blocks needed |
| 113 | + grid = lambda meta: (triton.cdiv(n_elements, ctx.block_size),) |
| 114 | + |
| 115 | + # Launch backward kernel |
| 116 | + triton_exp_backward_kernel[grid]( |
| 117 | + grad_output, |
| 118 | + output, |
| 119 | + grad_input, |
| 120 | + n_elements, |
| 121 | + BLOCK_SIZE=ctx.block_size, |
| 122 | + profile_mem=ctx.profile_mem, |
| 123 | + ) |
| 124 | + |
| 125 | + # Return gradients (None for block_size and profile_mem as they don't need gradients) |
| 126 | + return grad_input, None, None |
0 commit comments