@@ -128,7 +128,6 @@ impl DMatrix {
128
128
pub fn from_csr ( indptr : & [ usize ] , indices : & [ usize ] , data : & [ f32 ] , num_cols : Option < usize > ) -> XGBResult < Self > {
129
129
assert_eq ! ( indices. len( ) , data. len( ) ) ;
130
130
let mut handle = ptr:: null_mut ( ) ;
131
- let indptr: Vec < u64 > = indptr. iter ( ) . map ( |x| * x as u64 ) . collect ( ) ;
132
131
let indices: Vec < u32 > = indices. iter ( ) . map ( |x| * x as u32 ) . collect ( ) ;
133
132
let num_cols = num_cols. unwrap_or ( 0 ) ; // infer from data if 0
134
133
xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSREx ( indptr. as_ptr( ) ,
@@ -152,7 +151,6 @@ impl DMatrix {
152
151
pub fn from_csc ( indptr : & [ usize ] , indices : & [ usize ] , data : & [ f32 ] , num_rows : Option < usize > ) -> XGBResult < Self > {
153
152
assert_eq ! ( indices. len( ) , data. len( ) ) ;
154
153
let mut handle = ptr:: null_mut ( ) ;
155
- let indptr: Vec < u64 > = indptr. iter ( ) . map ( |x| * x as u64 ) . collect ( ) ;
156
154
let indices: Vec < u32 > = indices. iter ( ) . map ( |x| * x as u32 ) . collect ( ) ;
157
155
let num_rows = num_rows. unwrap_or ( 0 ) ; // infer from data if 0
158
156
xgb_call ! ( xgboost_sys:: XGDMatrixCreateFromCSCEx ( indptr. as_ptr( ) ,
@@ -349,7 +347,7 @@ mod tests {
349
347
350
348
#[ test]
351
349
fn read_num_cols ( ) {
352
- assert_eq ! ( read_train_matrix( ) . unwrap( ) . num_cols( ) , 126 ) ;
350
+ assert_eq ! ( read_train_matrix( ) . unwrap( ) . num_cols( ) , 127 ) ;
353
351
}
354
352
355
353
#[ test]
@@ -380,7 +378,7 @@ mod tests {
380
378
#[ test]
381
379
fn get_set_weights ( ) {
382
380
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
383
- assert_eq ! ( dmat. get_weights( ) . unwrap( ) , & [ ] ) ;
381
+ assert ! ( dmat. get_weights( ) . unwrap( ) . is_empty ( ) ) ;
384
382
385
383
let weight = [ 1.0 , 10.0 , 44.9555 ] ;
386
384
assert ! ( dmat. set_weights( & weight) . is_ok( ) ) ;
@@ -390,17 +388,20 @@ mod tests {
390
388
#[ test]
391
389
fn get_set_base_margin ( ) {
392
390
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
393
- assert_eq ! ( dmat. get_base_margin( ) . unwrap( ) , & [ ] ) ;
391
+ assert ! ( dmat. get_base_margin( ) . unwrap( ) . is_empty ( ) ) ;
394
392
395
393
let base_margin = [ 0.00001 , 0.000002 , 1.23 ] ;
394
+ println ! ( "rows: {:?}, {:?}" , dmat. num_rows( ) , base_margin. len( ) ) ;
395
+ let result = dmat. set_base_margin ( & base_margin) ;
396
+ println ! ( "{:?}" , result) ;
396
397
assert ! ( dmat. set_base_margin( & base_margin) . is_ok( ) ) ;
397
398
assert_eq ! ( dmat. get_base_margin( ) . unwrap( ) , base_margin) ;
398
399
}
399
400
400
401
#[ test]
401
402
fn get_set_group ( ) {
402
403
let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
403
- assert_eq ! ( dmat. get_group( ) . unwrap( ) , & [ ] ) ;
404
+ assert ! ( dmat. get_group( ) . unwrap( ) . is_empty ( ) ) ;
404
405
405
406
let group = [ 1 ] ;
406
407
assert ! ( dmat. set_group( & group) . is_ok( ) ) ;
0 commit comments