Skip to content

Commit 63b0a65

Browse files
authored
patch for #322 - error if cat levels > nbins (#323)
1 parent 72c4839 commit 63b0a65

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EvoTrees"
22
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
33
authors = ["jeremiedb <jeremie.db@evovest.com>"]
4-
version = "0.18.4"
4+
version = "0.18.5"
55

66
[deps]
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using CategoricalArrays
2+
using DataFrames
3+
using EvoTrees
4+
5+
nobs = 10_000
6+
nfeats = 3
7+
nlevels = 16
8+
nbins = 16
9+
10+
df = DataFrame(rand(nobs, nfeats), :auto)
11+
df.cat = rand(1:nlevels, nobs) |> categorical
12+
df.y = randn(nobs)
13+
length(unique(df.cat))
14+
target_name="y"
15+
feature_names = setdiff(names(df), [target_name])
16+
17+
config = EvoTreeRegressor(; nbins)
18+
19+
EvoTrees.fit(
20+
config,
21+
df;
22+
target_name,
23+
feature_names,
24+
)

src/fit-utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ function get_edges(df; feature_names, nbins, rng=Random.MersenneTwister(), kwarg
4141
edges[j] = levels(col)
4242
featbins[j] = length(edges[j])
4343
feattypes[j] = isordered(col) ? true : false
44-
@assert featbins[j] <= 255 "Max categorical levels currently limited to 255, $(feature_names[j]) has $(featbins[j])."
44+
featbins[j] <= nbins || error("
45+
Max categorical levels is limited to `nbins` ($nbins). Feature $(feature_names[j]) has $(featbins[j]) levels. Consider using larger `nbins`, up to 255.")
4546
elseif eltype(col) <: Real
4647
edges[j] = unique(quantile(col, (1:nbins-1) / nbins))
4748
featbins[j] = length(edges[j]) + 1

0 commit comments

Comments
 (0)