Skip to content

Commit 923528e

Browse files
committed
Improved error message for invalid arrays
1 parent a15c129 commit 923528e

File tree

5 files changed

+26
-0
lines changed

5 files changed

+26
-0
lines changed

lib/lightgbm/dataset.rb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ def construct
166166
ncol = data.first.count
167167
flat_data = data.flat_map { |v| v.fetch_values(*keys) }
168168
else
169+
data = data.to_a
170+
check_2d_array(data)
169171
nrow = data.count
170172
ncol = data.first.count
171173
flat_data = data.flatten

lib/lightgbm/inner_predictor.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def predict(data, start_iteration: 0, num_iteration: -1, raw_score: false, pred_
4949
data = data.to_a
5050
singular = !data.first.is_a?(Array)
5151
data = [data] if singular
52+
check_2d_array(data)
5253
end
5354

5455
preds, nrow =

lib/lightgbm/utils.rb

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ def set_verbosity(params)
2424
end
2525
end
2626

27+
def check_2d_array(data)
28+
ncol = data.first&.size || 0
29+
if !data.all? { |r| r.size == ncol }
30+
raise ArgumentError, "Rows have different sizes"
31+
end
32+
end
33+
2734
# for categorical, NaN and negative value are the same
2835
def handle_missing(data)
2936
data.map! { |v| v.nil? ? Float::NAN : v }

test/booster_test.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ def test_predict_rover
8888
end
8989
end
9090

91+
def test_predict_array_different_sizes
92+
x_test = [[1, 2], [3, 4, 5]]
93+
error = assert_raises(ArgumentError) do
94+
booster.predict(x_test)
95+
end
96+
assert_equal "Rows have different sizes", error.message
97+
end
98+
9199
def test_predict_raw_score
92100
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
93101
expected = [0.9823112229173586, 0.9583143724610858]

test/dataset_test.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,14 @@ def test_rover
113113
assert_equal ["x0", "x1", "x2", "x3"], dataset.feature_name
114114
end
115115

116+
def test_array_different_sizes
117+
data = [[1, 2], [3, 4, 5]]
118+
error = assert_raises(ArgumentError) do
119+
LightGBM::Dataset.new(data)
120+
end
121+
assert_equal "Rows have different sizes", error.message
122+
end
123+
116124
def test_copy
117125
regression_train.dup
118126
regression_train.clone

0 commit comments

Comments
 (0)