Skip to content

Commit a93f00d

Browse files
author
R. Teal Witter
committed
ooh maybe handling of empty and full set probabilities?
1 parent b1d6919 commit a93f00d

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

src/shapiq/approximator/sampling.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def sample(self, budget: int):
237237
budget += budget % 2
238238

239239
# Get sampling probabilities
240-
self.get_scale_for_sampling(budget-2) # minus 2 for empty and full coalitions
240+
self.get_scale_for_sampling(budget-2) # Exclude empty and full coalitions from budget
241241
sizes = np.arange(1, self.n_players)
242242
samples_per_size = self.symmetric_round_even(
243243
self.get_sampling_probs(sizes) * binom(self.n_players, sizes)
@@ -329,8 +329,11 @@ def coalitions_probability(self) -> np.ndarray:
329329
Returns:
330330
A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)``
331331
"""
332-
return self.get_sampling_probs(self.coalitions_size)
333-
332+
probs = self.get_sampling_probs(self.coalitions_size)
333+
# Replace the empty and full coalition probabilities with 1
334+
probs[self.empty_coalition_index] = 1.0
335+
probs[self.full_coalition_index] = 1.0
336+
return probs
334337

335338
@property
336339
def sampling_adjustment_weights(self) -> np.ndarray:
@@ -365,6 +368,19 @@ def empty_coalition_index(self) -> int | None:
365368
except IndexError:
366369
pass
367370
return None
371+
372+
@property
373+
def full_coalition_index(self) -> int | None:
374+
"""
375+
Returns:
376+
The index of the full coalition or ``None`` if the full coalition was not sampled.
377+
"""
378+
try:
379+
if self.coalitions_per_size[-1] >= 1:
380+
return int(np.where(self.coalitions_size == self.n_players)[0][0])
381+
except IndexError:
382+
pass
383+
return None
368384

369385
def set_random_state(self, random_state: int | None) -> None:
370386
'''

0 commit comments

Comments
 (0)