Skip to content

Commit e32683a

Browse files
committed
Added support for feature_names: "auto" to Dataset
1 parent 8684948 commit e32683a

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
## 0.4.0 (unreleased)
22

33
- Added support for hashes to `predict` method
4+
- Added support for `feature_names: "auto"` to `Dataset`
45
- Dropped support for Ruby < 3.1
56

67
## 0.3.4 (2024-07-28)

lib/lightgbm/dataset.rb

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,10 +142,18 @@ def construct
142142
ncol = data.column_count
143143
flat_data = data.to_a.flatten
144144
elsif daru?(data)
145+
if @feature_names == "auto"
146+
@feature_names = data.vectors.to_a
147+
end
145148
nrow, ncol = data.shape
146149
flat_data = data.map_rows(&:to_a).flatten
147-
elsif numo?(data) || rover?(data)
148-
data = data.to_numo if rover?(data)
150+
elsif numo?(data)
151+
nrow, ncol = data.shape
152+
elsif rover?(data)
153+
if @feature_names == "auto"
154+
@feature_names = data.keys
155+
end
156+
data = data.to_numo
149157
nrow, ncol = data.shape
150158
else
151159
nrow = data.count

test/dataset_test.rb

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,14 +52,19 @@ def test_dump_text
5252
def test_matrix
5353
data = Matrix.build(3, 3) { |row, col| row + col }
5454
label = Vector.elements([4, 5, 6])
55-
LightGBM::Dataset.new(data, label: label)
55+
dataset = LightGBM::Dataset.new(data, label: label)
56+
assert_equal ["Column_0", "Column_1", "Column_2"], dataset.feature_names
5657
end
5758

5859
def test_daru
5960
data = Daru::DataFrame.from_csv(data_path)
6061
label = data["y"]
6162
data = data.delete_vector("y")
62-
LightGBM::Dataset.new(data, label: label)
63+
dataset = LightGBM::Dataset.new(data, label: label)
64+
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_names
65+
66+
dataset = LightGBM::Dataset.new(data, label: label, feature_names: "auto")
67+
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_names
6368
end
6469

6570
def test_numo
@@ -68,7 +73,8 @@ def test_numo
6873
require "numo/narray"
6974
data = Numo::DFloat.new(3, 5).seq
7075
label = Numo::DFloat.new(3).seq
71-
LightGBM::Dataset.new(data, label: label)
76+
dataset = LightGBM::Dataset.new(data, label: label)
77+
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3", "Column_4"], dataset.feature_names
7278
end
7379

7480
def test_rover
@@ -77,7 +83,11 @@ def test_rover
7783
require "rover"
7884
data = Rover.read_csv(data_path)
7985
label = data.delete("y")
80-
LightGBM::Dataset.new(data, label: label)
86+
dataset = LightGBM::Dataset.new(data, label: label)
87+
assert_equal ["Column_0", "Column_1", "Column_2", "Column_3"], dataset.feature_names
88+
89+
dataset = LightGBM::Dataset.new(data, label: label, feature_names: "auto")
90+
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_names
8191
end
8292

8393
def test_copy

0 commit comments

Comments
 (0)