Skip to content

Commit b1d6919

Browse files
author
R. Teal Witter
committed
set random state
1 parent 28ba832 commit b1d6919

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

src/shapiq/approximator/sampling.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def __init__(
4646
self.distribution = np.concatenate(([0.0], self.distribution, [0.0]))
4747

4848
self.pairing_trick = pairing_trick
49-
self._rng = np.random.default_rng(seed=random_state)
5049
self.sample_with_replacement = sample_with_replacement
50+
self.set_random_state(random_state)
5151

5252
def get_sampling_probs(self, sizes: np.ndarray) -> np.ndarray:
5353
'''
@@ -324,24 +324,21 @@ def is_coalition_sampled(self) -> np.ndarray:
324324
return self.is_coalition_sampled[self.coalitions_size]
325325

326326
@property
327-
def sampling_adjustment_weights(self) -> np.ndarray:
327+
def coalitions_probability(self) -> np.ndarray:
328328
"""
329329
Returns:
330-
An array with adjusted weight for each coalition ``(n_coalitions,)``
330+
A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)``
331331
"""
332-
return 1 / self.get_sampling_probs(self.coalitions_size)
332+
return self.get_sampling_probs(self.coalitions_size)
333+
333334

334335
@property
335-
def coalitions_probability(self) -> np.ndarray:
336+
def sampling_adjustment_weights(self) -> np.ndarray:
336337
"""
337-
Returns the probability that each coalition was sampled according to the sampling procedure.
338-
339338
Returns:
340-
A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)`` or ``None``
341-
if the coalition probabilities are not available.
342-
339+
An array with adjusted weight for each coalition ``(n_coalitions,)``
343340
"""
344-
return self.get_sampling_probs(self.coalitions_size)
341+
return 1 / self.coalitions_probability
345342

346343
@property
347344
def coalitions_counter(self) -> np.ndarray:
@@ -368,4 +365,12 @@ def empty_coalition_index(self) -> int | None:
368365
except IndexError:
369366
pass
370367
return None
368+
369+
def set_random_state(self, random_state: int | None) -> None:
370+
'''
371+
Set the random state of the sampler.
372+
Args:
373+
random_state (int | None): Random seed for reproducibility
374+
'''
375+
self._rng = np.random.default_rng(seed=random_state)
371376

0 commit comments

Comments
 (0)