@@ -46,8 +46,8 @@ def __init__(
4646 self .distribution = np .concatenate (([0.0 ], self .distribution , [0.0 ]))
4747
4848 self .pairing_trick = pairing_trick
49- self ._rng = np .random .default_rng (seed = random_state )
5049 self .sample_with_replacement = sample_with_replacement
50+ self .set_random_state (random_state )
5151
5252 def get_sampling_probs (self , sizes : np .ndarray ) -> np .ndarray :
5353 '''
@@ -324,24 +324,21 @@ def is_coalition_sampled(self) -> np.ndarray:
324324 return self .is_coalition_sampled [self .coalitions_size ]
325325
326326 @property
327- def sampling_adjustment_weights (self ) -> np .ndarray :
327+ def coalitions_probability (self ) -> np .ndarray :
328328 """
329329 Returns:
330- An array with adjusted weight for each coalition ``(n_coalitions,)``
330+ A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)``
331331 """
332- return 1 / self .get_sampling_probs (self .coalitions_size )
332+ return self .get_sampling_probs (self .coalitions_size )
333+
333334
334335 @property
335- def coalitions_probability (self ) -> np .ndarray :
336+ def sampling_adjustment_weights (self ) -> np .ndarray :
336337 """
337- Returns the probability that each coalition was sampled according to the sampling procedure.
338-
339338 Returns:
340- A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)`` or ``None``
341- if the coalition probabilities are not available.
342-
339+ An array with adjusted weight for each coalition ``(n_coalitions,)``
343340 """
344- return self .get_sampling_probs ( self . coalitions_size )
341+ return 1 / self .coalitions_probability
345342
346343 @property
347344 def coalitions_counter (self ) -> np .ndarray :
@@ -368,4 +365,12 @@ def empty_coalition_index(self) -> int | None:
368365 except IndexError :
369366 pass
370367 return None
368+
369+ def set_random_state (self , random_state : int | None ) -> None :
370+ '''
371+ Set the random state of the sampler.
372+ Args:
373+ random_state (int | None): Random seed for reproducibility
374+ '''
375+ self ._rng = np .random .default_rng (seed = random_state )
371376
0 commit comments