This basic example shows how to optimize a simple PyTorch function for speedup.

For more advanced examples, including **[Metal/MLX](/examples/metal/README.md), [Triton](/examples/triton/README.md), [CUDA kernel optimization](/examples/cuda/README.md)**, and **[ML model optimization](/examples/spaceship-titanic/README.md)**, see the `README.md` files within the corresponding subdirectories under the [`examples/`](./examples/) folder.
```bash
  --additional-instructions "Fuse operations in the forward method while ensuring the max float deviation remains small. Maintain the same format of the code."
```
Note that if you have an NVIDIA GPU, change the device to `cuda`. If you are running this on Apple Silicon, set it to `mps`.
**Example 2: Optimizing MLX operations with instructions from a file**

Let's optimize a 2D convolution operation in [`mlx`](https://github.com/ml-explore/mlx) using [Metal](https://developer.apple.com/documentation/metal/). Sometimes additional context or instructions are too complex for a single command-line string; in that case, you can provide a path to a file containing them.
Given how useful causal multi-head self-attention is to transformers, it has seen wide adoption across ML engineering and AI research. It's great to keep things at a high level (in PyTorch) when doing research, but when moving to production you often need to write highly customized low-level kernels to make things run as fast as possible. The `weco` CLI can optimize kernels across a variety of abstraction levels and frameworks. Example 2 uses Metal, but let's explore two more frameworks:
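To make the starting point concrete, here is a minimal NumPy sketch of the causal multi-head self-attention computation being discussed. Function names, weight shapes, and the fused QKV layout are illustrative assumptions, not the actual code in any of the example directories:

```python
import numpy as np

def causal_self_attention(x, w_qkv, w_out, n_heads):
    """Naive causal multi-head self-attention (illustrative sketch only)."""
    T, d = x.shape
    hd = d // n_heads
    qkv = x @ w_qkv                                    # (T, 3d): fused QKV projection
    q, k, v = np.split(qkv, 3, axis=-1)
    # Split the feature dimension into heads: (n_heads, T, hd)
    q = q.reshape(T, n_heads, hd).transpose(1, 0, 2)
    k = k.reshape(T, n_heads, hd).transpose(1, 0, 2)
    v = v.reshape(T, n_heads, hd).transpose(1, 0, 2)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(hd)    # (n_heads, T, T)
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)   # mask out future positions
    scores = np.where(mask, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)       # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = (weights @ v).transpose(1, 0, 2).reshape(T, d)
    return out @ w_out
```

Kernel-level optimization of this routine typically fuses the projection, masking, softmax, and output matmul steps that this sketch performs one at a time.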
This example demonstrates optimizing a script for a Kaggle competition ([Spaceship Titanic](https://www.kaggle.com/competitions/spaceship-titanic/overview)) to improve classification accuracy. The additional instructions are provided via a separate file (`examples/spaceship-titanic/README.md`).

First, install the requirements for the example environment:
# Example: Optimizing PyTorch Self-Attention with CUDA
This example showcases using Weco to optimize a PyTorch causal multi-head self-attention implementation by generating custom [CUDA](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html) kernels. This approach targets low-level optimization beyond standard PyTorch (or even Triton), aiming for higher performance on NVIDIA GPUs.
This example uses a separate Markdown file (`guide.md`) to provide detailed instructions and context to the LLM.
## Setup
1. Ensure you are in the `examples/cuda` directory.
2. Install the required dependency:
   ```bash
   pip install torch
   ```
*(Note: This example requires a compatible NVIDIA GPU and the CUDA Toolkit installed on your system for compiling and running the generated CUDA code.)*
## Optimization Command
Run the following command to start the optimization process:
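The command itself was lost in this copy of the README; reassembled from the flags documented in the bullet list that follows, it would look roughly like this (the `weco` entry point and flag ordering are assumptions — check the upstream `examples/cuda/README.md` for the verbatim command):

```shell
weco --source optimize.py \
  --eval-command "python evaluate.py --solution-path optimize.py" \
  --metric speedup \
  --maximize true \
  --steps 30 \
  --model gemini-2.5-pro-exp-03-25 \
  --additional-instructions guide.md
```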
* `--source optimize.py`: The initial PyTorch self-attention code to be optimized with CUDA.
* `--eval-command "python evaluate.py --solution-path optimize.py"`: Runs the evaluation script, which compiles (if necessary) and benchmarks the CUDA-enhanced code in `optimize.py` against a baseline, printing the `speedup`.
* `--metric speedup`: The optimization target metric.
* `--maximize true`: Weco aims to increase the speedup.
* `--steps 30`: The number of optimization iterations.
* `--model gemini-2.5-pro-exp-03-25`: The LLM used for code generation.
* `--additional-instructions guide.md`: Points Weco to a file containing detailed instructions for the LLM on how to write the CUDA kernels, handle compilation (e.g., using `torch.utils.cpp_extension`), manage data types, and ensure correctness.
Weco will iteratively modify `optimize.py`, potentially generating and integrating CUDA C++ code, guided by the evaluation results and the instructions in `guide.md`.
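The `evaluate.py` script is not reproduced here; a minimal sketch of what such a harness typically measures — wall-clock speedup plus the maximum float deviation against a baseline, as the instructions in this README request — might look like the following. All names and thresholds are assumptions, not the example's actual code:

```python
import time
import numpy as np

def benchmark(fn, x, iters=50):
    """Warm up once, then return (output, mean seconds per call)."""
    fn(x)
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(x)
    return out, (time.perf_counter() - start) / iters

def evaluate(baseline_fn, optimized_fn, x):
    ref, t_ref = benchmark(baseline_fn, x)
    out, t_opt = benchmark(optimized_fn, x)
    # Correctness: largest element-wise deviation from the baseline output.
    max_dev = float(np.max(np.abs(ref - out)))
    # The metric Weco maximizes: baseline time over optimized time.
    speedup = t_ref / t_opt
    print(f"max float deviation: {max_dev:.2e}")
    print(f"speedup: {speedup:.2f}")
    return speedup, max_dev
```

A harness like this lets the optimizer reject candidates whose outputs drift too far from the baseline while still rewarding faster ones.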