Commit 182b103

Authored by aditvenk (Aditya Venkataraman)
Add backward operator support for vector_exp (#501)
- Add Triton kernel for backward.

Testing:

Forward:

```
$ python run.py --op vector_exp --metrics accuracy
First-k mode: Selected 16 sequential inputs starting from index 0 (total available: 16)
Input IDs to run: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
100%|████████████████████████████████████████████| 16/16 [00:02<00:00, 6.34it/s]
    x_val    triton_exp-accuracy    torch_compile_exp-accuracy
---------  ---------------------  ----------------------------
     4096                      1                             1
     8192                      1                             1
    16384                      1                             1
    32768                      1                             1
    65536                      1                             1
   131072                      1                             1
   262144                      1                             1
   524288                      1                             1
  1048576                      1                             1
  2097152                      1                             1
  4194304                      1                             1
  8388608                      1                             1
 16777216                      1                             1
 33554432                      1                             1
 67108864                      1                             1
134217728                      1                             1
  average                      1                             1
```

Backward:

```
$ python run.py --op vector_exp --metrics accuracy --bwd
First-k mode: Selected 16 sequential inputs starting from index 0 (total available: 16)
Input IDs to run: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
100%|████████████████████████████████████████████| 16/16 [00:02<00:00, 5.49it/s]
    x_val    triton_exp-accuracy    torch_compile_exp-accuracy
---------  ---------------------  ----------------------------
     4096                      1                             1
     8192                      1                             1
    16384                      1                             1
    32768                      1                             1
    65536                      1                             1
   131072                      1                             1
   262144                      1                             1
   524288                      1                             1
  1048576                      1                             1
  2097152                      1                             1
  4194304                      1                             1
  8388608                      1                             1
 16777216                      1                             1
 33554432                      1                             1
 67108864                      1                             1
134217728                      1                             1
  average                      1                             1
```

Co-authored-by: Aditya Venkataraman <[email protected]>
1 parent 55891ad · commit 182b103

File tree

2 files changed (+110, -17 lines)
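The backward kernel added below relies on the elementwise identity d/dx exp(x) = exp(x): the gradient is just the upstream grad_output multiplied by the saved forward output. A minimal plain-PyTorch sketch (not part of this commit) of the rule the accuracy runs above exercise:

```python
# Plain-PyTorch illustration (sketch, not from the commit) of the rule the new
# Triton backward kernel implements: for y = exp(x), dL/dx = dL/dy * exp(x) = dL/dy * y.
import torch

x = torch.randn(8, requires_grad=True)
y = torch.exp(x)

dy = torch.randn_like(y)       # upstream gradient (grad_output)
y.backward(dy)

manual_grad = dy * y.detach()  # grad_output * saved forward output
print(torch.allclose(x.grad, manual_grad))  # expect True
```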

tritonbench/operators/vector_exp/kernels.py

Lines changed: 88 additions & 0 deletions
@@ -1,3 +1,4 @@
+import torch
 import triton
 import triton.language as tl

@@ -36,3 +37,90 @@ def triton_exp_kernel(
     if profile_mem is not None:
         end = time()
         tl.store(profile_mem + pid, end - start)
+
+
+@triton.jit
+def triton_exp_backward_kernel(
+    grad_output_ptr,  # *Pointer* to grad_output vector.
+    output_ptr,  # *Pointer* to forward pass output vector (exp(x)).
+    grad_input_ptr,  # *Pointer* to grad_input vector.
+    n_elements,  # Size of the vector.
+    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
+    profile_mem=None,  # *Pointer* to profile_mem.
+):
+    if profile_mem is not None:
+        start = time()
+
+    # There are multiple 'programs' processing different data. We identify which program
+    # we are here:
+    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
+
+    # This program will process inputs that are offset from the initial data.
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Create a mask to guard memory operations against out-of-bounds accesses.
+    mask = offsets < n_elements
+
+    # Load grad_output and output from DRAM.
+    grad_output = tl.load(grad_output_ptr + offsets, mask=mask)
+    output = tl.load(output_ptr + offsets, mask=mask)
+
+    # Compute grad_input = grad_output * output (since d/dx(exp(x)) = exp(x)).
+    grad_input = grad_output * output
+
+    # Write grad_input back to DRAM.
+    tl.store(grad_input_ptr + offsets, grad_input, mask=mask)
+
+    if profile_mem is not None:
+        end = time()
+        tl.store(profile_mem + pid, end - start)
+
+
+class TritonExpFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx, x: torch.Tensor, block_size: int = 1024, profile_mem: torch.Tensor = None
+    ):
+        # Allocate output tensor.
+        output = torch.empty_like(x)
+        n_elements = output.numel()
+
+        # Launch grid - number of blocks needed.
+        grid = lambda meta: (triton.cdiv(n_elements, block_size),)
+
+        # Launch forward kernel.
+        triton_exp_kernel[grid](
+            x, output, n_elements, BLOCK_SIZE=block_size, profile_mem=profile_mem
+        )
+
+        # Save output for backward pass.
+        ctx.save_for_backward(output)
+        ctx.block_size = block_size
+        ctx.profile_mem = profile_mem
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        # Retrieve saved tensors.
+        (output,) = ctx.saved_tensors
+
+        # Allocate grad_input tensor.
+        grad_input = torch.empty_like(grad_output)
+        n_elements = grad_output.numel()
+
+        # Launch grid - number of blocks needed.
+        grid = lambda meta: (triton.cdiv(n_elements, ctx.block_size),)
+
+        # Launch backward kernel.
+        triton_exp_backward_kernel[grid](
+            grad_output,
+            output,
+            grad_input,
+            n_elements,
+            BLOCK_SIZE=ctx.block_size,
+            profile_mem=ctx.profile_mem,
+        )
+
+        # Return gradients (None for block_size and profile_mem as they don't need gradients).
+        return grad_input, None, None
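For a quick standalone sanity check of the wiring above, here is a minimal sketch (not part of the diff) that applies `TritonExpFunction` to a CUDA tensor and compares the resulting gradient with eager `torch.exp`. It assumes a CUDA-capable machine and that `tritonbench` is importable; the import path is inferred from the file location in this diff.

```python
# Sketch only: assumes a CUDA device and that tritonbench is on PYTHONPATH.
import torch

from tritonbench.operators.vector_exp.kernels import TritonExpFunction

x = torch.randn(1 << 16, device="cuda", requires_grad=True)

# Forward + backward through the custom Triton op (default block_size=1024, no profiling buffer).
y = TritonExpFunction.apply(x)
y.backward(torch.ones_like(y))
triton_grad = x.grad.clone()

# Reference gradient from eager torch.exp.
x.grad = None
torch.exp(x).backward(torch.ones_like(x))

print(torch.allclose(triton_grad, x.grad, rtol=1e-5, atol=1e-5))  # expect True
```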

tritonbench/operators/vector_exp/operator.py

Lines changed: 22 additions & 17 deletions
@@ -10,7 +10,7 @@
     register_metric,
 )
 
-from .kernels import triton_exp_kernel
+from .kernels import TritonExpFunction
 
 
 class Operator(BenchmarkOperator):
@@ -46,27 +46,14 @@ def duration(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
 
     @register_benchmark()
     def triton_exp(self, x: torch.Tensor):
-        # We need to preallocate the output.
-        output = torch.empty_like(x)
-        n_elements = output.numel()
-        # The SPMD launch grid denotes the number of kernel instances that run in parallel.
-        # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
-        # In this case, we use a 1D grid where the size is the number of blocks:
-        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-        # NOTE:
-        # - Each torch.tensor object is implicitly converted into a pointer to its first element.
-        # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
-        # - Don't forget to pass meta-parameters as keyword arguments.
-
+        n_elements = x.numel()
         # Prepare a memory buffer to store the profiled data, with the size equal to the number of programs.
         BLOCK_SIZE = 1024
         n_programs = triton.cdiv(n_elements, BLOCK_SIZE)
         profile_mem = torch.empty(n_programs, dtype=torch.int64, device=self.device)
 
         def _inner():
-            triton_exp_kernel[grid](
-                x, output, n_elements, BLOCK_SIZE=1024, profile_mem=profile_mem
-            )
+            output = TritonExpFunction.apply(x, BLOCK_SIZE, profile_mem)
             return {"output": output, "profile_mem": profile_mem}
 
         return _inner
@@ -133,5 +120,23 @@ def _plot(size, provider):
 
     def get_input_iter(self) -> Generator:
         for size in self.get_x_vals():
-            x = torch.rand(size, device=self.device, dtype=self.dtype)
+            x = torch.rand(
+                size, device=self.device, dtype=self.dtype, requires_grad=True
+            )
             yield (x,)
+
+    def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
+        def _bwd():
+            x = self.example_inputs[0]
+            # clear existing grad
+            x.grad = None
+            y = fwd_fn()["output"]
+            dy = torch.randn_like(y)
+            y.backward(dy, retain_graph=True)
+            return {"output": x.grad}
+
+        return _bwd
+
+    def get_grad_to_none(self, args) -> List[torch.Tensor]:
+        x = args[0]
+        return [x]
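Outside the harness, the backward measurement added here boils down to seeding `y.backward(dy)` with a random upstream gradient and reading back `x.grad`, with the gradient cleared between runs (which is what `get_grad_to_none` lets the harness do). A standalone sketch of that pattern, with eager `torch.exp` standing in for the benchmarked op (names below are illustrative, not harness code):

```python
# Standalone sketch of the backward-benchmark pattern, with eager torch.exp
# standing in for the op under test (illustrative names, not harness code).
import torch

x = torch.rand(1 << 20, requires_grad=True)

def fwd_fn():
    return {"output": torch.exp(x)}

def bwd_fn():
    x.grad = None                      # clear stale gradient (cf. get_grad_to_none)
    y = fwd_fn()["output"]
    dy = torch.randn_like(y)           # random upstream gradient
    y.backward(dy, retain_graph=True)
    return {"output": x.grad}

for _ in range(3):                     # a benchmark times repeated calls like this
    grads = bwd_fn()

print(grads["output"].shape)           # same shape as x
```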
