Skip to content

Commit cbe925a

Browse files
committed
refactored constraint search for groups
1 parent b28330e commit cbe925a

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

generalizedtrees/split.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from logging import getLogger
88
from typing import Collection, Container, Iterable, Protocol, Optional
99
from functools import cached_property
10+
from collections import defaultdict
1011
from operator import itemgetter
1112
import heapq
1213

@@ -761,16 +762,18 @@ def _groups_split_search(self, node, s_data, s_y, all_constraint_candidates, fea
761762

762763
# Get best atomic split for each feature in the group
763764
try:
764-
starting_constraint_dict = {
765-
feature: max([
766-
ScoredItem(
767-
score = self.split_scorer.score(node, BinarySplit(constraint), s_data, s_y),
768-
item = constraint)
769-
for constraint in all_constraint_candidates
770-
if constraint.feature == feature
771-
]).item
772-
for feature in feature_group
773-
}
765+
best_constraint_scores = {}
766+
for constraint in all_constraint_candidates:
767+
f = constraint.feature
768+
if f in feature_group:
769+
score = self.split_scorer.score(node, BinarySplit(constraint), s_data, s_y)
770+
if f not in best_constraint_scores or best_constraint_scores[f].score < score:
771+
best_constraint_scores[f] = ScoredItem(
772+
score=score,
773+
item=constraint)
774+
775+
starting_constraint_dict = {feat: si.item for feat, si in best_constraint_scores.items()}
776+
774777
except:
775778
logger.debug('Failure in group split search')
776779
logger.debug(f'Feature group: {feature_group}')

0 commit comments

Comments
 (0)