Skip to content

Commit f293d3a

Browse files
committed
Fixed deprecation warning
1 parent 85bc90a commit f293d3a

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

lib/xgboost/dmatrix.rb

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,15 @@ def weight=(weight)
7676
end
7777

7878
def group=(group)
79-
c_data = ::FFI::MemoryPointer.new(:int, group.size)
80-
c_data.write_array_of_int(group)
81-
check_call FFI.XGDMatrixSetUIntInfo(handle, "group", c_data, group.size)
79+
c_data = ::FFI::MemoryPointer.new(:uint, group.size)
80+
c_data.write_array_of_uint(group)
81+
interface = {
82+
shape: [group.length],
83+
typestr: "|u4",
84+
data: [c_data.address, false],
85+
version: 3
86+
}
87+
check_call FFI.XGDMatrixSetInfoFromInterface(handle, "group", JSON.generate(interface))
8288
end
8389

8490
def label
@@ -89,6 +95,10 @@ def weight
8995
float_info("weight")
9096
end
9197

98+
def group
99+
uint_info("group_ptr")
100+
end
101+
92102
def num_row
93103
out = ::FFI::MemoryPointer.new(:uint64)
94104
check_call FFI.XGDMatrixNumRow(handle, out)
@@ -234,7 +244,15 @@ def float_info(field)
234244
out_len = ::FFI::MemoryPointer.new(:uint64)
235245
out_dptr = ::FFI::MemoryPointer.new(:float, num_row)
236246
check_call FFI.XGDMatrixGetFloatInfo(handle, field, out_len, out_dptr)
237-
out_dptr.read_pointer.read_array_of_float(num_row)
247+
out_dptr.read_pointer.null? ? nil : out_dptr.read_pointer.read_array_of_float(num_row)
248+
end
249+
250+
def uint_info(field)
251+
num_row ||= num_row()
252+
out_len = ::FFI::MemoryPointer.new(:uint64)
253+
out_dptr = ::FFI::MemoryPointer.new(:uint, num_row)
254+
check_call FFI.XGDMatrixGetUIntInfo(handle, field, out_len, out_dptr)
255+
out_dptr.read_pointer.null? ? nil : out_dptr.read_pointer.read_array_of_uint(num_row)
238256
end
239257

240258
def validate_feature_info(feature_info, n_features, is_column_split, name)

lib/xgboost/ffi.rb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ module FFI
2121

2222
# dmatrix
2323
attach_function :XGDMatrixCreateFromMat, %i[pointer uint64 uint64 float pointer], :int
24-
attach_function :XGDMatrixSetUIntInfo, %i[pointer string pointer uint64], :int
24+
attach_function :XGDMatrixSetInfoFromInterface, %i[pointer string string], :int
2525
attach_function :XGDMatrixSetStrFeatureInfo, %i[pointer string pointer uint64], :int
2626
attach_function :XGDMatrixGetStrFeatureInfo, %i[pointer string pointer pointer], :int
2727
attach_function :XGDMatrixNumRow, %i[pointer pointer], :int
@@ -33,6 +33,7 @@ module FFI
3333
attach_function :XGDMatrixSaveBinary, %i[pointer string int], :int
3434
attach_function :XGDMatrixSetFloatInfo, %i[pointer string pointer uint64], :int
3535
attach_function :XGDMatrixGetFloatInfo, %i[pointer string pointer pointer], :int
36+
attach_function :XGDMatrixGetUIntInfo, %i[pointer string pointer pointer], :int
3637

3738
# booster
3839
attach_function :XGBoosterCreate, %i[pointer int pointer], :int

test/dmatrix_test.rb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@ def test_label
55
data = [[1, 2], [3, 4]]
66
label = [1, 2]
77
dataset = XGBoost::DMatrix.new(data, label: label)
8-
assert label, dataset.label
8+
assert_equal label, dataset.label
99
end
1010

1111
def test_weight
1212
data = [[1, 2], [3, 4]]
1313
weight = [1, 2]
1414
dataset = XGBoost::DMatrix.new(data, weight: weight)
15-
assert weight, dataset.weight
15+
assert_equal weight, dataset.weight
16+
end
17+
18+
def test_group
19+
data = [[1, 2], [3, 4]]
20+
dataset = XGBoost::DMatrix.new(data)
21+
dataset.group = [1, 2]
22+
assert_equal [0, 1], dataset.group
1623
end
1724

1825
def test_feature_names_and_types

0 commit comments

Comments
 (0)