Skip to content

Commit be4612f

Browse files
committed
update dmatrix
1 parent 1e35b99 commit be4612f

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

examples/basic/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ publish = false
66

77
[dependencies]
88
xgboost = { path = "../../" }
9-
sprs = "0.6"
9+
sprs = "0.11"
1010
log = "0.4"
1111
env_logger = "0.5"

examples/basic/src/main.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn main() {
9797
let mut data = Vec::new();
9898

9999
let reader = BufReader::new(File::open("../../xgboost-sys/xgboost/demo/data/agaricus.txt.train").unwrap());
100-
let mut current_row: u64 = 0;
100+
let mut current_row = 0;
101101
for line in reader.lines() {
102102
let line = line.unwrap();
103103
let sample: Vec<&str> = line.split_whitespace().collect();
@@ -106,7 +106,7 @@ fn main() {
106106
for entry in &sample[1..] {
107107
let pair: Vec<&str> = entry.split(':').collect();
108108
rows.push(current_row);
109-
cols.push(pair[0].parse::<u64>().unwrap());
109+
cols.push(pair[0].parse::<usize>().unwrap());
110110
data.push(pair[1].parse::<f32>().unwrap());
111111
}
112112

@@ -116,11 +116,12 @@ fn main() {
116116
// work out size of sparse matrix from max row/col values
117117
let shape = ((*rows.iter().max().unwrap() + 1) as usize,
118118
(*cols.iter().max().unwrap() + 1) as usize);
119+
let num_col = Some((*cols.iter().max().unwrap() + 1) as usize);
119120
let triplet_mat = sprs::TriMatBase::from_triplets(shape, rows, cols, data);
120121
let csr_mat = triplet_mat.to_csr();
121122

122123
let indices: Vec<usize> = csr_mat.indices().into_iter().map(|i| *i as usize).collect();
123-
let mut dtrain = DMatrix::from_csr(csr_mat.indptr(), &indices, csr_mat.data(), None).unwrap();
124+
let mut dtrain = DMatrix::from_csr(csr_mat.indptr().raw_storage(), &indices, csr_mat.data(), num_col).unwrap();
124125
dtrain.set_labels(&labels).unwrap();
125126

126127
let training_params = parameters::TrainingParametersBuilder::default().dtrain(&dtrain).build().unwrap();

src/dmatrix.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ impl DMatrix {
124124
/// `data[indptr[i]:indptr[i+1]]`.
125125
///
126126
/// If `num_cols` is set to None, number of columns will be inferred from given data.
127-
pub fn from_csr(indptr: &[u64], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
127+
pub fn from_csr(indptr: &[usize], indices: &[usize], data: &[f32], num_cols: Option<usize>) -> XGBResult<Self> {
128128
assert_eq!(indices.len(), data.len());
129129
let mut handle = ptr::null_mut();
130+
let indptr: Vec<u64> = indptr.iter().map(|x| *x as u64).collect();
130131
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
131132
let num_cols = num_cols.unwrap_or(0); // infer from data if 0
132133
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSREx(indptr.as_ptr(),
@@ -147,9 +148,10 @@ impl DMatrix {
147148
/// `data[indptr[i]:indptr[i+1]]`.
148149
///
149150
/// If `num_rows` is set to None, number of rows will be inferred from given data.
150-
pub fn from_csc(indptr: &[u64], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
151+
pub fn from_csc(indptr: &[usize], indices: &[usize], data: &[f32], num_rows: Option<usize>) -> XGBResult<Self> {
151152
assert_eq!(indices.len(), data.len());
152153
let mut handle = ptr::null_mut();
154+
let indptr: Vec<u64> = indptr.iter().map(|x| *x as u64).collect();
153155
let indices: Vec<u32> = indices.iter().map(|x| *x as u32).collect();
154156
let num_rows = num_rows.unwrap_or(0); // infer from data if 0
155157
xgb_call!(xgboost_sys::XGDMatrixCreateFromCSCEx(indptr.as_ptr(),

0 commit comments

Comments
 (0)