@@ -7,7 +7,8 @@ use xgboost_sys;
7
7
8
8
use super :: { XGBResult , XGBError } ;
9
9
10
- static KEY_ROOT_INDEX : & ' static str = "root_index" ;
10
+ static KEY_GROUP_PTR : & ' static str = "group_ptr" ;
11
+ static KEY_GROUP : & ' static str = "group" ;
11
12
static KEY_LABEL : & ' static str = "label" ;
12
13
static KEY_WEIGHT : & ' static str = "weight" ;
13
14
static KEY_BASE_MARGIN : & ' static str = "base_margin" ;
@@ -230,20 +231,6 @@ impl DMatrix {
230
231
Ok ( DMatrix :: new ( out_handle) ?)
231
232
}
232
233
233
- /// Gets the specified root index of each instance, can be used for multi task setting.
234
- ///
235
- /// See the XGBoost documentation for more information.
236
- pub fn get_root_index ( & self ) -> XGBResult < & [ u32 ] > {
237
- self . get_uint_info ( KEY_ROOT_INDEX )
238
- }
239
-
240
- /// Sets the specified root index of each instance, can be used for multi task setting.
241
- ///
242
- /// See the XGBoost documentation for more information.
243
- pub fn set_root_index ( & mut self , array : & [ u32 ] ) -> XGBResult < ( ) > {
244
- self . set_uint_info ( KEY_ROOT_INDEX , array)
245
- }
246
-
247
234
/// Get ground truth labels for each row of this matrix.
248
235
pub fn get_labels ( & self ) -> XGBResult < & [ f32 ] > {
249
236
self . get_float_info ( KEY_LABEL )
@@ -282,9 +269,20 @@ impl DMatrix {
282
269
///
283
270
/// See the XGBoost documentation for more information.
284
271
pub fn set_group ( & mut self , group : & [ u32 ] ) -> XGBResult < ( ) > {
285
- xgb_call ! ( xgboost_sys:: XGDMatrixSetGroup ( self . handle, group. as_ptr( ) , group. len( ) as u64 ) )
272
+ // same as xgb_call!(xgboost_sys::XGDMatrixSetGroup(self.handle, group.as_ptr(), group.len() as u64))
273
+ self . set_uint_info ( KEY_GROUP , group)
274
+ }
275
+
276
+ /// Get the index for the beginning and end of a group.
277
+ ///
278
+ /// Needed when the learning task is ranking.
279
+ ///
280
+ /// See the XGBoost documentation for more information.
281
+ pub fn get_group ( & self ) -> XGBResult < & [ u32 ] > {
282
+ self . get_uint_info ( KEY_GROUP_PTR )
286
283
}
287
284
285
+
288
286
fn get_float_info ( & self , field : & str ) -> XGBResult < & [ f32 ] > {
289
287
let field = ffi:: CString :: new ( field) . unwrap ( ) ;
290
288
let mut out_len = 0 ;
@@ -313,7 +311,6 @@ impl DMatrix {
313
311
field. as_ptr( ) ,
314
312
& mut out_len,
315
313
& mut out_dptr) ) ?;
316
-
317
314
Ok ( unsafe { slice:: from_raw_parts ( out_dptr as * mut c_uint , out_len as usize ) } )
318
315
}
319
316
@@ -370,16 +367,6 @@ mod tests {
370
367
// TODO: check contents as well, if possible
371
368
}
372
369
373
- #[ test]
374
- fn get_set_root_index ( ) {
375
- let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
376
- assert_eq ! ( dmat. get_root_index( ) . unwrap( ) , & [ ] ) ;
377
-
378
- let root_index = [ 3 , 22 , 1 ] ;
379
- assert ! ( dmat. set_root_index( & root_index) . is_ok( ) ) ;
380
- assert_eq ! ( dmat. get_root_index( ) . unwrap( ) , & [ 3 , 22 , 1 ] ) ;
381
- }
382
-
383
370
#[ test]
384
371
fn get_set_labels ( ) {
385
372
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
@@ -395,7 +382,7 @@ mod tests {
395
382
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
396
383
assert_eq ! ( dmat. get_weights( ) . unwrap( ) , & [ ] ) ;
397
384
398
- let weight = [ 1.0 , 10.0 , - 123.456789 , 44.9555 ] ;
385
+ let weight = [ 1.0 , 10.0 , 44.9555 ] ;
399
386
assert ! ( dmat. set_weights( & weight) . is_ok( ) ) ;
400
387
assert_eq ! ( dmat. get_weights( ) . unwrap( ) , weight) ;
401
388
}
@@ -411,11 +398,13 @@ mod tests {
411
398
}
412
399
413
400
#[ test]
414
- fn set_group ( ) {
401
+ fn get_set_group ( ) {
415
402
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
403
+ assert_eq ! ( dmat. get_group( ) . unwrap( ) , & [ ] ) ;
416
404
417
- let group = [ 1 , 2 , 3 ] ;
405
+ let group = [ 1 ] ;
418
406
assert ! ( dmat. set_group( & group) . is_ok( ) ) ;
407
+ assert_eq ! ( dmat. get_group( ) . unwrap( ) , & [ 0 , 1 ] ) ;
419
408
}
420
409
421
410
#[ test]
@@ -426,7 +415,7 @@ mod tests {
426
415
427
416
let dmat = DMatrix :: from_csr ( & indptr, & indices, & data, None ) . unwrap ( ) ;
428
417
assert_eq ! ( dmat. num_rows( ) , 4 ) ;
429
- assert_eq ! ( dmat. num_cols( ) , 3 ) ;
418
+ assert_eq ! ( dmat. num_cols( ) , 0 ) ; // https://github.com/dmlc/xgboost/pull/7265
430
419
431
420
let dmat = DMatrix :: from_csr ( & indptr, & indices, & data, Some ( 10 ) ) . unwrap ( ) ;
432
421
assert_eq ! ( dmat. num_rows( ) , 4 ) ;
@@ -477,7 +466,7 @@ mod tests {
477
466
assert_eq ! ( dmat. slice( & [ 1 ] ) . unwrap( ) . shape( ) , ( 1 , 2 ) ) ;
478
467
assert_eq ! ( dmat. slice( & [ 0 , 1 ] ) . unwrap( ) . shape( ) , ( 2 , 2 ) ) ;
479
468
assert_eq ! ( dmat. slice( & [ 3 , 2 , 1 ] ) . unwrap( ) . shape( ) , ( 3 , 2 ) ) ;
480
- assert ! ( dmat. slice( & [ 10 , 11 , 12 ] ) . is_err ( ) ) ;
469
+ assert_eq ! ( dmat. slice( & [ 10 , 11 , 12 ] ) . unwrap ( ) . shape ( ) , ( 0 , 0 ) ) ;
481
470
}
482
471
483
472
#[ test]
0 commit comments