Skip to content

Commit ee51728

Browse files
committed
Fix feature names to accommodate classes
1 parent 6be7f8c commit ee51728

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/sknnr/transformers/_gbnode_transformer.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,17 @@ def _set_tree_weights(self, X, y) -> list[NDArray[np.float64]]:
296296

297297
def get_feature_names_out(self) -> NDArray:
298298
check_is_fitted(self, "estimators_")
299-
return np.asarray(
300-
[
301-
f"gb{i}_tree{j}"
302-
for i in range(len(self.estimators_))
303-
for j in range(self.estimators_[i].n_estimators)
304-
],
305-
dtype=object,
306-
)
299+
feature_names = []
300+
for i, est in enumerate(self.estimators_):
301+
# Regression and binary classification have 1 tree per iteration
302+
if est.n_trees_per_iteration_ == 1:
303+
feature_names.extend(
304+
[f"gb{i}_tree{k}" for k in range(est.n_estimators)]
305+
)
306+
# Multi-class classification has n_classes trees per iteration
307+
else:
308+
for j in range(est.n_trees_per_iteration_):
309+
feature_names.extend(
310+
[f"gb{i}_cls{j}_tree{k}" for k in range(est.n_estimators)]
311+
)
312+
return np.asarray(feature_names, dtype=object)

0 commit comments

Comments
 (0)