Skip to content

Commit 2c30708

Browse files
committed
Fix the build scripts with new static linking (could become even more
static).
1 parent 05d87d3 commit 2c30708

File tree

2 files changed

+43
-13
lines changed

2 files changed

+43
-13
lines changed

candle-extensions/candle-flash-attn-v1/build.rs

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,39 @@
33
// variable in order to cache the compiled artifacts and avoid recompiling too often.
44
use anyhow::{Context, Result};
55
use rayon::prelude::*;
6+
use std::fs;
67
use std::path::PathBuf;
78
use std::str::FromStr;
89

9-
const KERNEL_FILES: [&str; 4] = [
10-
"flash_api.cu",
11-
"fmha_fwd_hdim32.cu",
12-
"fmha_fwd_hdim64.cu",
13-
"fmha_fwd_hdim128.cu",
14-
];
10+
// const KERNEL_FILES: [&str; 4] = [
11+
// "flash_api.cu",
12+
// "fmha_fwd_hdim32.cu",
13+
// "fmha_fwd_hdim64.cu",
14+
// "fmha_fwd_hdim128.cu",
15+
// ];
16+
17+
/// Recursively reads the filenames in a directory and stores them in a Vec.
18+
fn _read_dir_recursively(dir_path: &PathBuf, paths: &mut Vec<PathBuf>) -> std::io::Result<()> {
19+
for entry in fs::read_dir(dir_path)? {
20+
let entry = entry?;
21+
let path = entry.path();
22+
23+
if path.is_dir() {
24+
_read_dir_recursively(&path, paths)?;
25+
} else {
26+
paths.push(path);
27+
}
28+
}
29+
30+
Ok(())
31+
}
32+
33+
/// Recursively reads the filenames in a directory and stores them in a Vec.
34+
fn read_dir_recursively(dir_path: &PathBuf) -> std::io::Result<Vec<PathBuf>> {
35+
let mut paths = Vec::new();
36+
_read_dir_recursively(dir_path, &mut paths)?;
37+
Ok(paths)
38+
}
1539

1640
fn main() -> Result<()> {
1741
let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
@@ -25,12 +49,11 @@ fn main() -> Result<()> {
2549
.unwrap();
2650

2751
println!("cargo:rerun-if-changed=build.rs");
28-
for kernel_file in KERNEL_FILES.iter() {
29-
println!("cargo:rerun-if-changed=kernels/{kernel_file}");
52+
53+
let paths = read_dir_recursively(&PathBuf::from_str("kernels")?)?;
54+
for file in paths.iter() {
55+
println!("cargo:rerun-if-changed={}", file.display());
3056
}
31-
println!("cargo:rerun-if-changed=kernels/**.h");
32-
println!("cargo:rerun-if-changed=kernels/**.cuh");
33-
println!("cargo:rerun-if-changed=kernels/fmha/**.h");
3457
let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
3558
let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
3659
Err(_) =>
@@ -57,12 +80,17 @@ fn main() -> Result<()> {
5780
let out_file = build_dir.join("libflashattentionv1.a");
5881

5982
let kernel_dir = PathBuf::from("kernels");
60-
let cu_files: Vec<_> = KERNEL_FILES
83+
let kernels: Vec<_> = paths
84+
.iter()
85+
.filter(|f| f.extension().map(|ext| ext == "cu").unwrap_or_default())
86+
.collect();
87+
let cu_files: Vec<_> = kernels
6188
.iter()
6289
.map(|f| {
6390
let mut obj_file = out_dir.join(f);
91+
fs::create_dir_all(obj_file.parent().unwrap()).unwrap();
6492
obj_file.set_extension("o");
65-
(kernel_dir.join(f), obj_file)
93+
(f, obj_file)
6694
})
6795
.collect();
6896
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());

candle-extensions/candle-layer-norm/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ fn set_cuda_include_dir() -> Result<()> {
176176
.chain(roots)
177177
.find(|path| path.join("include").join("cuda.h").is_file())
178178
.context("cannot find include/cuda.h")?;
179+
println!("cargo:rustc-link-search={}", root.join("lib").display());
180+
println!("cargo:rustc-link-search={}", root.join("lib64").display());
179181
println!(
180182
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
181183
root.join("include").display()

0 commit comments

Comments
 (0)