Skip to content

Commit 21f48c1

Browse files
author
R. Teal Witter
committed
unique sampling matrix, bug fix
1 parent d7f1a8d commit 21f48c1

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

src/shapiq/approximator/sampling.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ def add_one_sample(self, indices: Sequence[int]):
137137
indices (Sequence[int]): Indices of players in the coalition.
138138
Returns:
139139
None: Sample is stored in self.coalitions_matrix and self.sampled_coalitions_dict
140-
'''
141-
self.coalitions_matrix[self._coalition_idx, indices] = 1
142-
if tuple(sorted(indices)) not in self.sampled_coalitions_dict:
140+
'''
141+
if tuple(sorted(indices)) not in self.sampled_coalitions_dict:
142+
self.coalitions_matrix[self._coalition_idx, indices] = 1
143143
self.sampled_coalitions_dict[tuple(sorted(indices))] = 0
144-
self.sampled_coalitions_dict[tuple(sorted(indices))] += 1
145-
self._coalition_idx += 1
144+
self._coalition_idx += 1
145+
self.sampled_coalitions_dict[tuple(sorted(indices))] += 1
146146

147147
def symmetric_round_even(self, x: np.ndarray) -> np.ndarray:
148148
'''
@@ -323,7 +323,7 @@ def is_coalition_size_sampled(self) -> np.ndarray:
323323
The Boolean array whether the coalition size was sampled ``(n_players + 1,)``
324324
"""
325325
is_size_sampled = np.zeros(self.n + 1, dtype=bool)
326-
is_size_sampled[0] = is_size_sampled[self.n] = True
326+
is_size_sampled[0] = is_size_sampled[self.n] = False
327327
is_size_sampled[1:-1] = (self.samples_per_size != binom(self.n, np.arange(1, self.n)))
328328
return is_size_sampled
329329

@@ -339,7 +339,7 @@ def is_coalition_sampled(self) -> np.ndarray:
339339
def coalitions_probability(self) -> np.ndarray:
340340
"""
341341
Returns:
342-
A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)``
342+
The probability of sampling each coalition ``(n_coalitions,)``
343343
"""
344344
probs = self.get_sampling_probs(self.coalitions_size)
345345
# Replace the empty and full coalition probabilities with 1
@@ -418,4 +418,3 @@ def set_random_state(self, random_state: int | None) -> None:
418418
random_state (int | None): Random seed for reproducibility
419419
'''
420420
self._rng = np.random.default_rng(seed=random_state)
421-

0 commit comments

Comments
 (0)