
Commit e59dc73

update readme and refactor examples
1 parent c566c22 commit e59dc73

File tree: 12 files changed, +646 −72 lines changed

README.md

Lines changed: 45 additions & 8 deletions

Here's how `weco` can be applied to common ML engineering tasks. The updated section reads:

### Examples

**Example 1: Optimizing simple PyTorch operations**

```bash
cd examples/hello-kernel-world
pip install torch
weco --source optimize.py \
     --eval-command "python evaluate.py --solution-path optimize.py --device cpu" \
     --metric speedup \
     --maximize true \
     --steps 15 \
     --model claude-3-7-sonnet-20250219 \
     --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
```

Note that if you have an NVIDIA GPU, change the device to `cuda`. If you are running this on Apple Silicon, set it to `mps`.

**Example 2: Optimizing MLX operations with instructions from a file**

Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes, additional context or instructions are too complex for a single command-line string. You can provide a path to a file containing these instructions.

```bash
cd examples/metal
pip install mlx
weco --source optimize.py \
     --eval-command "python evaluate.py --solution-path optimize.py" \
     --metric speedup \
     --maximize true \
     --steps 30 \
     --model o3-mini \
     --additional-instructions examples.rst
```

**Example 3: Level-agnostic optimization: causal self-attention with Triton & CUDA**

Given how useful causal multi-head self-attention is to transformers, it has seen wide adoption across ML engineering and AI research. It's great to keep things at a high level (in PyTorch) when doing research, but when moving to production you often need to write highly customized low-level kernels to make things run as fast as they can. The `weco` CLI can optimize kernels across a variety of abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:

1. [Triton](https://github.com/triton-lang/triton)

   ```bash
   cd examples/triton
   pip install torch triton
   weco --source optimize.py \
        --eval-command "python evaluate.py --solution-path optimize.py" \
        --metric speedup \
        --maximize true \
        --steps 30 \
        --model gemini-2.5-pro-preview-03-25 \
        --additional-instructions "Use triton to optimize the code while ensuring a small max float diff. Maintain the same code format."
   ```

2. [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html)

   ```bash
   cd examples/cuda
   pip install torch
   weco --source optimize.py \
        --eval-command "python evaluate.py --solution-path optimize.py" \
        --metric speedup \
        --maximize true \
        --steps 30 \
        --model gemini-2.5-pro-preview-03-25 \
        --additional-instructions guide.md
   ```
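In each example, the `--eval-command` script is responsible for printing the value of the metric named by `--metric` (see `examples/cuda/evaluate.py` below, which prints a `speedup: X.XXx` line). A minimal pure-Python sketch of that computation, with hypothetical timings:

```python
# Hypothetical average runtimes (in ms) measured by an evaluation script.
t_baseline_ms = 4.2   # baseline model
t_optimized_ms = 1.5  # candidate (optimized) model

# For --metric speedup with --maximize true, the score weco optimizes is
# baseline time divided by optimized time: > 1.0 means the candidate is faster.
speedup = t_baseline_ms / t_optimized_ms
print(f"speedup: {speedup:.2f}x")  # prints "speedup: 2.80x"
```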
---
### Command Line Arguments

examples/cuda/evaluate.py

Lines changed: 157 additions & 0 deletions
```python
import time
import sys
import os
import pathlib
import importlib.util
import traceback
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


########################################################
# Baseline
########################################################
class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


########################################################
# Weco Solution
########################################################
def load_module_from_path(module_path: str, add_to_sys_modules: bool = False):
    # Import the candidate solution as a fresh module keyed by its file stem;
    # it is only registered in sys.modules if explicitly requested.
    module_path = pathlib.Path(module_path)
    name = module_path.stem
    spec = importlib.util.spec_from_file_location(name, module_path)
    mod = importlib.util.module_from_spec(spec)  # type: ignore
    if add_to_sys_modules:
        sys.modules[name] = mod
    spec.loader.exec_module(mod)  # type: ignore
    return mod
```
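The helper above is the standard `importlib` recipe for importing a Python file by path rather than by package name. A self-contained usage sketch (the temp-file "candidate" module here is purely illustrative):

```python
import importlib.util
import pathlib
import tempfile


def load_module_from_path(module_path):
    # Same recipe as in evaluate.py: import a .py file by its filesystem path.
    module_path = pathlib.Path(module_path)
    spec = importlib.util.spec_from_file_location(module_path.stem, module_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


with tempfile.TemporaryDirectory() as d:
    candidate = pathlib.Path(d) / "optimize.py"
    candidate.write_text("ANSWER = 42\n")
    mod = load_module_from_path(candidate)
    print(mod.ANSWER)  # prints 42
```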
70+
71+
########################################################
72+
# Benchmark
73+
########################################################
74+
os.environ["MAX_JOBS"] = "1" # number of workers for building with ninja
75+
76+
77+
def get_inputs(batch_size, seq_len, n_embd, device):
78+
return torch.randn(batch_size, seq_len, n_embd, device=device, dtype=torch.float32)
79+
80+
81+
def bench(f, inputs, n_warmup, n_rep):
82+
with torch.no_grad():
83+
# warmup
84+
for _ in range(n_warmup):
85+
f(inputs) # noqa
86+
87+
# benchmark
88+
t_avg = 0.0
89+
for _ in range(n_rep):
90+
torch.cuda.empty_cache() # Clear cache before timing
91+
start_time = time.time()
92+
f(inputs)
93+
torch.cuda.synchronize() # Wait for all computations to complete
94+
t_avg += time.time() - start_time
95+
t_avg /= n_rep * 1e-3
96+
return t_avg
97+
98+
99+
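The conversion `t_avg /= n_rep * 1e-3` is easy to misread: it divides the accumulated total (in seconds) by `n_rep` and then by `1e-3`, yielding the average per-repetition time in milliseconds. The same arithmetic in isolation, with hypothetical numbers:

```python
n_rep = 5000
t_total_s = 0.25       # hypothetical accumulated wall time in seconds
t_avg = t_total_s
t_avg /= n_rep * 1e-3  # equivalent to t_total_s / n_rep * 1000
print(f"{t_avg:.3f} ms per rep")  # 0.25 s over 5000 reps -> 0.050 ms per rep
```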
```python
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--solution-path", type=str, required=True)
    args = parser.parse_args()

    # benchmarking parameters
    n_correctness_trials = 10
    n_warmup = 1000
    n_rep = 5000

    # init parameters
    max_seqlen = 512
    seq_len = 256
    n_embd = 768
    n_head = 8
    # turn off dropout to measure correctness well
    attn_pdrop = 0.0
    resid_pdrop = 0.0

    # input parameters
    batch_size = 32

    # load solution module
    try:
        torch.manual_seed(0)
        solution_module = load_module_from_path(args.solution_path, add_to_sys_modules=False)
        solution_model = solution_module.Model(
            n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
        ).to("cuda")
        assert isinstance(solution_model, nn.Module)
    except Exception:
        print(f"Candidate module initialization failed: {traceback.format_exc()}")
        sys.exit(1)

    torch.manual_seed(0)
    baseline_model = Model(
        n_embd=n_embd, n_head=n_head, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop, max_seqlen=max_seqlen
    ).to("cuda")

    # measure correctness
    max_diff_avg = 0.0
    for _ in range(n_correctness_trials):
        inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
        with torch.no_grad():
            baseline_output = baseline_model(inputs)
            optimized_output = solution_model(inputs)
        max_diff_avg += torch.max(torch.abs(optimized_output - baseline_output)).item()
    max_diff_avg /= n_correctness_trials
    print(f"max float diff between values of baseline and optimized model: {max_diff_avg}")

    # measure performance
    inputs = get_inputs(batch_size=batch_size, seq_len=seq_len, n_embd=n_embd, device="cuda")
    t_avg_baseline = bench(baseline_model, inputs, n_warmup, n_rep)
    print(f"baseline time: {t_avg_baseline:.2f}ms")
    t_avg_optimized = bench(solution_model, inputs, n_warmup, n_rep)
    print(f"optimized time: {t_avg_optimized:.2f}ms")
    print(f"speedup: {t_avg_baseline / t_avg_optimized:.2f}x")
```

examples/cuda/guide.md

Lines changed: 113 additions & 0 deletions
# Writing In-line CUDA Kernels: 101

This document outlines the strategy to improve speedup by writing fused and optimized CUDA kernels using a single-file implementation.

## Requirements

- **Single-File Implementation:** Develop fused CUDA kernels within one file.
- **No Fallback Implementation:** Do not include any alternative or fallback code.
- **Simplicity & Readability:** Write simple, easy-to-understand code and include clear comments.
- **Avoid Templates:** Use plain fused kernel functions without templates.
- **Multiple Kernels Allowed:** You can define more than one kernel in the file if needed.
- **Model Class Requirement:** The solution must include a class `Model` (an `nn.Module` subclass), with the main computation in its `forward` method.
- **Preserve Initialization:** Do not change the initialization of the `Model` class.
- **Focus on Efficiency:** Concentrate solely on efficient PyTorch and CUDA coding without capturing logs.
- **Error Handling:** Any terminal output or errors will be reviewed by an LLM for feedback.

## GPU Hardware Specifications

Here are some details on the hardware you have access to.

```json
{
    "GPU Architecture": "Ampere",
    "GPU Memory": "40GB",
    "Memory Bandwidth": "1935 GB/s",
    "FP64 TFLOPS": "9.7",
    "FP64 Tensor Core TFLOPS": "19.5",
    "FP32 TFLOPS": "19.5",
    "TF32 Tensor Core TFLOPS": "156 (312 with sparsity)",
    "BFLOAT16 Tensor Core TFLOPS": "312 (624 with sparsity)",
    "FP16 Tensor Core TFLOPS": "312 (624 with sparsity)",
    "INT8 Tensor Core TOPS": "624 (1248 with sparsity)",
    "Register File Size": "64K 32-bit registers per SM",
    "Maximum number of registers per thread": "255",
    "Maximum number of thread blocks per SM": "32",
    "Shared memory capacity per SM": "164 KB",
    "Maximum shared memory per thread block": "163 KB"
}
```

## Baseline Code

The baseline implementation of the `Model` class simply performs an element-wise addition.

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, a, b):
        return a + b
```

## Optimized Code

The optimized version employs a custom CUDA kernel for fused element-wise addition. The kernel is defined and compiled inline using PyTorch's `load_inline`.

```python
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline

# Define the custom CUDA kernel for element-wise addition
elementwise_add_source = '''
#include <torch/extension.h>
#include <cuda_runtime.h>

// CUDA kernel for element-wise addition
__global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
        out[idx] = a[idx] + b[idx];
    }
}

// Launch function for the CUDA kernel
torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b) {
    auto size = a.numel();
    auto out = torch::zeros_like(a);
    const int block_size = 256;
    const int num_blocks = (size + block_size - 1) / block_size;
    elementwise_add_kernel<<<num_blocks, block_size>>>(a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), size);
    return out;
}
'''

# C++ function prototype declaration
elementwise_add_cpp_source = "torch::Tensor elementwise_add_cuda(torch::Tensor a, torch::Tensor b);"

# Compile the inline CUDA code for element-wise addition
elementwise_add = load_inline(
    name="elementwise_add",
    cpp_sources=elementwise_add_cpp_source,
    cuda_sources=elementwise_add_source,
    functions=["elementwise_add_cuda"],
    verbose=True,
)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.elementwise_add = elementwise_add

    def forward(self, a, b):
        return self.elementwise_add.elementwise_add_cuda(a, b)
```
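The launch-configuration line `(size + block_size - 1) / block_size` in the kernel above is integer ceiling division: it rounds the block count up so that every element is covered by a thread, with the `if (idx < size)` guard discarding the overhang. A quick pure-Python check of the same arithmetic:

```python
def num_blocks(size: int, block_size: int) -> int:
    # Integer ceiling division, as used for the CUDA grid size above.
    return (size + block_size - 1) // block_size


# 1000 elements with 256-thread blocks need 4 blocks (3 would cover only 768).
print(num_blocks(1000, 256))  # prints 4
print(num_blocks(1024, 256))  # prints 4 (exact multiple: no extra block)
```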

examples/cuda/optimize.py

Lines changed: 44 additions & 0 deletions
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, max_seqlen):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # output projection
        self.c_proj = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(max_seqlen, max_seqlen)).view(1, 1, max_seqlen, max_seqlen))
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
```
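The mask-then-softmax step in this model has a simple scalar analogue. Below is a dependency-free sketch (illustrative only, not part of the repo) of causal masking for a single attention row: scores at positions to the right of the query index are set to `-inf`, so the softmax assigns them exactly zero weight:

```python
import math


def causal_softmax_row(scores, query_index):
    # Mask out future positions (j > query_index) with -inf, then softmax.
    masked = [s if j <= query_index else float("-inf") for j, s in enumerate(scores)]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


weights = causal_softmax_row([0.5, 1.0, 2.0, 1.5], query_index=1)
print(weights)  # future positions (indices 2 and 3) get exactly 0 weight
```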
