@@ -76,9 +76,15 @@ def weight=(weight)
7676 end
7777
7878 def group = ( group )
79- c_data = ::FFI ::MemoryPointer . new ( :int , group . size )
80- c_data . write_array_of_int ( group )
81- check_call FFI . XGDMatrixSetUIntInfo ( handle , "group" , c_data , group . size )
79+ c_data = ::FFI ::MemoryPointer . new ( :uint , group . size )
80+ c_data . write_array_of_uint ( group )
81+ interface = {
82+ shape : [ group . length ] ,
83+ typestr : "|u4" ,
84+ data : [ c_data . address , false ] ,
85+ version : 3
86+ }
87+ check_call FFI . XGDMatrixSetInfoFromInterface ( handle , "group" , JSON . generate ( interface ) )
8288 end
8389
8490 def label
@@ -89,6 +95,10 @@ def weight
8995 float_info ( "weight" )
9096 end
9197
98+ def group
99+ uint_info ( "group_ptr" )
100+ end
101+
92102 def num_row
93103 out = ::FFI ::MemoryPointer . new ( :uint64 )
94104 check_call FFI . XGDMatrixNumRow ( handle , out )
@@ -234,7 +244,15 @@ def float_info(field)
234244 out_len = ::FFI ::MemoryPointer . new ( :uint64 )
235245 out_dptr = ::FFI ::MemoryPointer . new ( :float , num_row )
236246 check_call FFI . XGDMatrixGetFloatInfo ( handle , field , out_len , out_dptr )
237- out_dptr . read_pointer . read_array_of_float ( num_row )
247+ out_dptr . read_pointer . null? ? nil : out_dptr . read_pointer . read_array_of_float ( num_row )
248+ end
249+
250+ def uint_info ( field )
251+ num_row ||= num_row ( )
252+ out_len = ::FFI ::MemoryPointer . new ( :uint64 )
253+ out_dptr = ::FFI ::MemoryPointer . new ( :uint , num_row )
254+ check_call FFI . XGDMatrixGetUIntInfo ( handle , field , out_len , out_dptr )
255+ out_dptr . read_pointer . null? ? nil : out_dptr . read_pointer . read_array_of_uint ( num_row )
238256 end
239257
240258 def validate_feature_info ( feature_info , n_features , is_column_split , name )
0 commit comments