
Commit c3bb5bf

Use cudaforge for kernel build (#3346)
* Use cudaforge for kernel build
* Fix clippy
* Update cudaforge to v0.1.2
* Fix build candle-examples

Co-authored-by: Eric Buehler <ericlbuehler@gmail.com>
1 parent 971e7ed commit c3bb5bf

File tree: 13 files changed (+116, -437 lines)

Cargo.toml

Lines changed: 0 additions & 2 deletions
```diff
@@ -13,7 +13,6 @@ members = [
 ]
 exclude = [
     "candle-book",
-    "candle-flash-attn-build",
     "candle-flash-attn",
     "candle-flash-attn-v3",
     "candle-kernels",
@@ -38,7 +37,6 @@ anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
 candle = { path = "./candle-core", package = "candle-core", version = "0.9.2" }
 candle-datasets = { path = "./candle-datasets", version = "0.9.2" }
-candle-flash-attn-build = { path = "candle-flash-attn-build", version = "0.9.2" }
 candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.2" }
 candle-flash-attn-v3 = { path = "./candle-flash-attn-v3", version = "0.9.2" }
 candle-kernels = { path = "./candle-kernels", version = "0.9.2" }
```

README.md

Lines changed: 26 additions & 12 deletions
````diff
@@ -368,20 +368,34 @@ conditions](https://huggingface.co/meta-llama/Llama-2-7b-hf), and set up your
 authentication token. See issue
 [#350](https://github.com/huggingface/candle/issues/350) for more details.
 
-#### Missing cute/cutlass headers when compiling flash-attn
+#### Docker build
+
+When building CUDA kernels inside a Dockerfile, nvidia-smi cannot be used to auto-detect compute capability.
+
+You must explicitly set CUDA_COMPUTE_CAP, for example:
 
 ```
-In file included from kernels/flash_fwd_launch_template.h:11:0,
-                 from kernels/flash_fwd_hdim224_fp16_sm80.cu:5:
-kernels/flash_fwd_kernel.h:8:10: fatal error: cute/algorithm/copy.hpp: No such file or directory
- #include <cute/algorithm/copy.hpp>
-          ^~~~~~~~~~~~~~~~~~~~~~~~~
-compilation terminated.
-Error: nvcc error while compiling:
-```
-[cutlass](https://github.com/NVIDIA/cutlass) is provided as a git submodule so you may want to run the following command to check it in properly.
-```bash
-git submodule update --init
+FROM nvidia/cuda:12.9.0-devel-ubuntu22.04
+
+# Install git and curl
+RUN set -eux; \
+    apt-get update; \
+    apt-get install -y curl git ca-certificates;
+
+# Install Rust
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+
+# Clone candle repo
+RUN git clone https://github.com/huggingface/candle.git
+
+# Set compute capability for the build
+ARG CUDA_COMPUTE_CAP=90
+ENV CUDA_COMPUTE_CAP=${CUDA_COMPUTE_CAP}
+
+# Build with explicit compute cap
+WORKDIR /app
+COPY . .
+RUN cargo build --release --features cuda
 ```
 
 #### Compiling with flash-attention fails
````
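Since the added Dockerfile declares `ARG CUDA_COMPUTE_CAP=90`, the capability can also be overridden at image build time with a standard build argument, e.g. `docker build --build-arg CUDA_COMPUTE_CAP=86 -t candle-cuda .` (the `candle-cuda` tag here is only an illustrative name).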

candle-examples/Cargo.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -60,7 +60,7 @@ tokio = "1.48.0"
 
 [build-dependencies]
 anyhow = { workspace = true }
-bindgen_cuda = { version = "0.1.5", optional = true }
+cudaforge = { version = "0.1.2", optional = true }
 hf-hub = { workspace = true, features = ["tokio"] }
 
 [features]
@@ -75,7 +75,7 @@ cuda = [
     "candle/cuda",
     "candle-nn/cuda",
     "candle-transformers/cuda",
-    "dep:bindgen_cuda",
+    "dep:cudaforge",
 ]
 cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
```

candle-examples/build.rs

Lines changed: 13 additions & 14 deletions
```diff
@@ -1,42 +1,42 @@
 #![allow(unused)]
-use anyhow::{Context, Result};
-use std::env;
-use std::io::Write;
-use std::path::{Path, PathBuf};
 mod buildtime_downloader;
 use buildtime_downloader::download_model;
 
 struct KernelDirectories {
     kernel_glob: &'static str,
     rust_target: &'static str,
-    include_dirs: &'static [&'static str],
 }
 
 const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
     kernel_glob: "examples/custom-ops/kernels/*.cu",
     rust_target: "examples/custom-ops/cuda_kernels.rs",
-    include_dirs: &[],
 }];
 
-fn main() -> Result<()> {
+fn main() {
     println!("cargo::rerun-if-changed=build.rs");
 
     #[cfg(feature = "cuda")]
     {
+        use std::env;
+        use std::path::{Path, PathBuf};
         // Added: Get the safe output directory from the environment.
         let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
 
         for kdir in KERNEL_DIRS.iter() {
-            let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
-            let bindings = builder.build_ptx().unwrap();
-
             // Changed: This now writes to a safe path inside $OUT_DIR.
             let safe_target = out_dir.join(
                 Path::new(kdir.rust_target)
                     .file_name()
-                    .context("Failed to get filename from rust_target")?,
+                    .expect("Failed to get filename from rust_target"),
             );
-            bindings.write(safe_target).unwrap()
+
+            let bindings = cudaforge::KernelBuilder::new()
+                .source_glob(kdir.kernel_glob)
+                .build_ptx()
+                .expect("Failed to build ptx");
+            bindings
+                .write(safe_target)
+                .expect("Failed to write ptx bindings");
         }
     }
 
@@ -45,7 +45,6 @@ fn main() -> Result<()> {
     // Example value:
     // CANDLE_BUILDTIME_MODEL_REVISION="sentence-transformers/all-MiniLM-L6-v2:c9745ed1d9f207416be6d2e6f8de32d1f16199bf"
     if let Some(model_rev) = core::option_env!("CANDLE_BUILDTIME_MODEL_REVISION") {
-        buildtime_downloader::download_model(model_rev)?;
+        buildtime_downloader::download_model(model_rev).expect("Model download failed!");
     }
-    Ok(())
 }
```
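Stripped of the diff markers, the migrated flow reduces to the following minimal sketch of a cudaforge-based build script. It assumes only the `KernelBuilder::new()`, `source_glob()`, `build_ptx()` and `write()` calls shown in this commit; the glob and output file name are illustrative.

```rust
// build.rs — minimal sketch mirroring the pattern above.
// The source glob and output file name are illustrative, not fixed API.
fn main() {
    println!("cargo::rerun-if-changed=build.rs");

    #[cfg(feature = "cuda")]
    {
        use std::env;
        use std::path::PathBuf;

        // Write generated bindings into Cargo's build output directory.
        let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());

        // Compile the matching .cu files to PTX and emit Rust bindings for them.
        let bindings = cudaforge::KernelBuilder::new()
            .source_glob("examples/custom-ops/kernels/*.cu")
            .build_ptx()
            .expect("Failed to build ptx");
        bindings
            .write(out_dir.join("cuda_kernels.rs"))
            .expect("Failed to write ptx bindings");
    }
}
```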

candle-examples/buildtime_downloader.rs

Lines changed: 1 addition & 6 deletions
```diff
@@ -1,10 +1,5 @@
-use anyhow::{Context, Result};
+use anyhow::Result;
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use std::{
-    fs::{self, File},
-    io::copy,
-    path::Path,
-};
 
 pub fn download_model(model_and_revision: &str) -> Result<()> {
     let (model_id, revision) = match model_and_revision.split_once(":") {
```

candle-flash-attn-build/Cargo.toml

Lines changed: 0 additions & 10 deletions
This file was deleted.

candle-flash-attn-build/src/lib.rs

Lines changed: 0 additions & 102 deletions
This file was deleted.

candle-flash-attn-v3/Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -19,7 +19,7 @@ half = { version = "2.3.1", features = ["num-traits"] }
 anyhow = { version = "1", features = ["backtrace"] }
 num_cpus = "1.15.0"
 rayon = "1.7.0"
-candle-flash-attn-build = { path = "../candle-flash-attn-build", version = "0.9.2" }
+cudaforge = "0.1"
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
```
