Skip to content

Commit 87f7144

Browse files
committed
fix dmatrix unit test
1 parent c6e3216 commit 87f7144

File tree

1 file changed

+21
-32
lines changed

1 file changed

+21
-32
lines changed

src/dmatrix.rs

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use xgboost_sys;
77

88
use super::{XGBResult, XGBError};
99

10-
static KEY_ROOT_INDEX: &'static str = "root_index";
10+
static KEY_GROUP_PTR: &'static str = "group_ptr";
11+
static KEY_GROUP: &'static str = "group";
1112
static KEY_LABEL: &'static str = "label";
1213
static KEY_WEIGHT: &'static str = "weight";
1314
static KEY_BASE_MARGIN: &'static str = "base_margin";
@@ -230,20 +231,6 @@ impl DMatrix {
230231
Ok(DMatrix::new(out_handle)?)
231232
}
232233

233-
/// Gets the specified root index of each instance, can be used for multi task setting.
234-
///
235-
/// See the XGBoost documentation for more information.
236-
pub fn get_root_index(&self) -> XGBResult<&[u32]> {
237-
self.get_uint_info(KEY_ROOT_INDEX)
238-
}
239-
240-
/// Sets the specified root index of each instance, can be used for multi task setting.
241-
///
242-
/// See the XGBoost documentation for more information.
243-
pub fn set_root_index(&mut self, array: &[u32]) -> XGBResult<()> {
244-
self.set_uint_info(KEY_ROOT_INDEX, array)
245-
}
246-
247234
/// Get ground truth labels for each row of this matrix.
248235
pub fn get_labels(&self) -> XGBResult<&[f32]> {
249236
self.get_float_info(KEY_LABEL)
@@ -282,9 +269,20 @@ impl DMatrix {
282269
///
283270
/// See the XGBoost documentation for more information.
284271
pub fn set_group(&mut self, group: &[u32]) -> XGBResult<()> {
285-
xgb_call!(xgboost_sys::XGDMatrixSetGroup(self.handle, group.as_ptr(), group.len() as u64))
272+
// same as xgb_call!(xgboost_sys::XGDMatrixSetGroup(self.handle, group.as_ptr(), group.len() as u64))
273+
self.set_uint_info(KEY_GROUP, group)
274+
}
275+
276+
/// Get the index for the beginning and end of a group.
277+
///
278+
/// Needed when the learning task is ranking.
279+
///
280+
/// See the XGBoost documentation for more information.
281+
pub fn get_group(&self) -> XGBResult<&[u32]> {
282+
self.get_uint_info(KEY_GROUP_PTR)
286283
}
287284

285+
288286
fn get_float_info(&self, field: &str) -> XGBResult<&[f32]> {
289287
let field = ffi::CString::new(field).unwrap();
290288
let mut out_len = 0;
@@ -313,7 +311,6 @@ impl DMatrix {
313311
field.as_ptr(),
314312
&mut out_len,
315313
&mut out_dptr))?;
316-
317314
Ok(unsafe { slice::from_raw_parts(out_dptr as *mut c_uint, out_len as usize) })
318315
}
319316

@@ -370,16 +367,6 @@ mod tests {
370367
// TODO: check contents as well, if possible
371368
}
372369

373-
#[test]
374-
fn get_set_root_index() {
375-
let mut dmat = read_train_matrix().unwrap();
376-
assert_eq!(dmat.get_root_index().unwrap(), &[]);
377-
378-
let root_index = [3, 22, 1];
379-
assert!(dmat.set_root_index(&root_index).is_ok());
380-
assert_eq!(dmat.get_root_index().unwrap(), &[3, 22, 1]);
381-
}
382-
383370
#[test]
384371
fn get_set_labels() {
385372
let mut dmat = read_train_matrix().unwrap();
@@ -395,7 +382,7 @@ mod tests {
395382
let mut dmat = read_train_matrix().unwrap();
396383
assert_eq!(dmat.get_weights().unwrap(), &[]);
397384

398-
let weight = [1.0, 10.0, -123.456789, 44.9555];
385+
let weight = [1.0, 10.0, 44.9555];
399386
assert!(dmat.set_weights(&weight).is_ok());
400387
assert_eq!(dmat.get_weights().unwrap(), weight);
401388
}
@@ -411,11 +398,13 @@ mod tests {
411398
}
412399

413400
#[test]
414-
fn set_group() {
401+
fn get_set_group() {
415402
let mut dmat = read_train_matrix().unwrap();
403+
assert_eq!(dmat.get_group().unwrap(), &[]);
416404

417-
let group = [1, 2, 3];
405+
let group = [1];
418406
assert!(dmat.set_group(&group).is_ok());
407+
assert_eq!(dmat.get_group().unwrap(), &[0, 1]);
419408
}
420409

421410
#[test]
@@ -426,7 +415,7 @@ mod tests {
426415

427416
let dmat = DMatrix::from_csr(&indptr, &indices, &data, None).unwrap();
428417
assert_eq!(dmat.num_rows(), 4);
429-
assert_eq!(dmat.num_cols(), 3);
418+
assert_eq!(dmat.num_cols(), 0); // https://github.com/dmlc/xgboost/pull/7265
430419

431420
let dmat = DMatrix::from_csr(&indptr, &indices, &data, Some(10)).unwrap();
432421
assert_eq!(dmat.num_rows(), 4);
@@ -477,7 +466,7 @@ mod tests {
477466
assert_eq!(dmat.slice(&[1]).unwrap().shape(), (1, 2));
478467
assert_eq!(dmat.slice(&[0, 1]).unwrap().shape(), (2, 2));
479468
assert_eq!(dmat.slice(&[3, 2, 1]).unwrap().shape(), (3, 2));
480-
assert!(dmat.slice(&[10, 11, 12]).is_err());
469+
assert_eq!(dmat.slice(&[10, 11, 12]).unwrap().shape(), (0, 0));
481470
}
482471

483472
#[test]

0 commit comments

Comments
 (0)