Skip to content

Commit 1c89c26

Browse files
authored
Merge pull request #2 from postgresml/montana/upgrades
add cuda to the build, upgrade bindgen version and fix some tests
2 parents 2101ed6 + 13fb2a7 commit 1c89c26

File tree

6 files changed

+43
-14
lines changed

6 files changed

+43
-14
lines changed

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ readme = "README.md"
1212
[dependencies]
1313
xgboost-sys = { path = "xgboost-sys" }
1414
libc = "0.2"
15-
derive_builder = "0.5"
15+
derive_builder = "0.11"
1616
log = "0.4"
1717
tempfile = "3.0"
1818
indexmap = "1.0"
19+
20+
[features]
21+
cuda = ["xgboost-sys/cuda"]

src/booster.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -774,12 +774,12 @@ mod tests {
774774
}
775775

776776
let train_metrics = booster.evaluate(&dmat_train).unwrap();
777-
assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634);
778-
assert_eq!(*train_metrics.get("map@4-").unwrap(), 0.001274);
777+
assert_eq!(*train_metrics.get("logloss").unwrap(), 0.006634271);
778+
assert_eq!(*train_metrics.get("map@4-").unwrap(), 0.0012738854);
779779

780780
let test_metrics = booster.evaluate(&dmat_test).unwrap();
781-
assert_eq!(*test_metrics.get("logloss").unwrap(), 0.00692);
782-
assert_eq!(*test_metrics.get("map@4-").unwrap(), 0.005155);
781+
assert_eq!(*test_metrics.get("logloss").unwrap(), 0.006919953);
782+
assert_eq!(*test_metrics.get("map@4-").unwrap(), 0.005154639);
783783

784784
let v = booster.predict(&dmat_test).unwrap();
785785
assert_eq!(v.len(), dmat_test.num_rows());

src/dmatrix.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ impl DMatrix {
128128
pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
129129
assert_eq!(indices.len(), data.len());
130130
let mut handle = ptr::null_mut();
131-
let indptr: Vec<u64> = indptr.iter().map(|x| *x as u64).collect();
132131
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
133132
let num_cols = num_cols.unwrap_or(0); // infer from data if 0
134133
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(),
@@ -152,7 +151,6 @@ impl DMatrix {
152151
pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
153152
assert_eq!(indices.len(), data.len());
154153
let mut handle = ptr::null_mut();
155-
let indptr: Vec<u64> = indptr.iter().map(|x| *x as u64).collect();
156154
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
157155
let num_rows = num_rows.unwrap_or(0); // infer from data if 0
158156
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(),
@@ -349,7 +347,7 @@ mod tests {
349347

350348
#[test]
351349
fn read_num_cols() {
352-
assert_eq!(read_train_matrix().unwrap().num_cols(), 126);
350+
assert_eq!(read_train_matrix().unwrap().num_cols(), 127);
353351
}
354352

355353
#[test]
@@ -380,7 +378,7 @@ mod tests {
380378
#[test]
381379
fn get_set_weights() {
382380
let mut dmat = read_train_matrix().unwrap();
383-
assert_eq!(dmat.get_weights().unwrap(), &[]);
381+
assert!(dmat.get_weights().unwrap().is_empty());
384382

385383
let weight = [1.0, 10.0, 44.9555];
386384
assert!(dmat.set_weights(&weight).is_ok());
@@ -390,17 +388,20 @@ mod tests {
390388
#[test]
391389
fn get_set_base_margin() {
392390
let mut dmat = read_train_matrix().unwrap();
393-
assert_eq!(dmat.get_base_margin().unwrap(), &[]);
391+
assert!(dmat.get_base_margin().unwrap().is_empty());
394392

395393
let base_margin = [0.00001, 0.000002, 1.23];
394+
println!("rows: {:?}, {:?}", dmat.num_rows(), base_margin.len());
395+
let result = dmat.set_base_margin(&base_margin);
396+
println!("{:?}", result);
396397
assert!(dmat.set_base_margin(&base_margin).is_ok());
397398
assert_eq!(dmat.get_base_margin().unwrap(), base_margin);
398399
}
399400

400401
#[test]
401402
fn get_set_group() {
402403
let mut dmat = read_train_matrix().unwrap();
403-
assert_eq!(dmat.get_group().unwrap(), &[]);
404+
assert!(dmat.get_group().unwrap().is_empty());
404405

405406
let group = [1];
406407
assert!(dmat.set_group(&group).is_ok());

xgboost-sys/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ readme = "README.md"
1313
libc = "0.2"
1414

1515
[build-dependencies]
16-
bindgen = "0.59"
16+
bindgen = "0.61"
1717
cmake = "0.1"
18+
19+
[features]
20+
cuda = []

xgboost-sys/build.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ fn main() {
2222
}
2323

2424
// CMake
25+
#[cfg(feature = "cuda")]
26+
let dst = Config::new(&xgb_root)
27+
.uses_cxx11()
28+
.define("BUILD_STATIC_LIB", "ON")
29+
.define("USE_CUDA", "ON")
30+
.define("BUILD_WITH_CUDA", "ON")
31+
.define("BUILD_WITH_CUDA_CUB", "ON")
32+
.build();
33+
34+
#[cfg(not(feature = "cuda"))]
2535
let dst = Config::new(&xgb_root)
2636
.uses_cxx11()
2737
.define("BUILD_STATIC_LIB", "ON")
@@ -34,7 +44,11 @@ fn main() {
3444
.clang_args(&["-x", "c++", "-std=c++11"])
3545
.clang_arg(format!("-I{}", xgb_root.join("include").display()))
3646
.clang_arg(format!("-I{}", xgb_root.join("rabit/include").display()))
37-
.clang_arg(format!("-I{}", xgb_root.join("dmlc-core/include").display()))
47+
.clang_arg(format!("-I{}", xgb_root.join("dmlc-core/include").display()));
48+
49+
#[cfg(feature = "cuda")]
50+
let bindings = bindings.clang_arg("-I/usr/local/cuda/include");
51+
let bindings = bindings
3852
.generate()
3953
.expect("Unable to generate bindings.");
4054

@@ -60,4 +74,12 @@ fn main() {
6074
println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
6175
println!("cargo:rustc-link-lib=static=dmlc");
6276
println!("cargo:rustc-link-lib=static=xgboost");
77+
78+
#[cfg(feature = "cuda")]
79+
{
80+
println!("cargo:rustc-link-search={}", "/usr/local/cuda/lib64");
81+
println!("cargo:rustc-link-search={}", "/usr/local/cuda/lib64/stubs");
82+
println!("cargo:rustc-link-lib=dylib=cuda");
83+
println!("cargo:rustc-link-lib=dylib=cudart");
84+
}
6385
}

xgboost-sys/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ mod tests {
2626
let mut num_cols = 0;
2727
let ret_val = unsafe { XGDMatrixNumCol(handle, &mut num_cols) };
2828
assert_eq!(ret_val, 0);
29-
assert_eq!(num_cols, 127);
29+
assert_eq!(num_cols, 126);
3030

3131
let ret_val = unsafe { XGDMatrixFree(handle) };
3232
assert_eq!(ret_val, 0);

0 commit comments

Comments
 (0)