@@ -237,7 +237,7 @@ def sample(self, budget: int):
237237 budget += budget % 2
238238
239239 # Get sampling probabilities
240- self .get_scale_for_sampling (budget - 2 ) # minus 2 for empty and full coalitions
240+ self .get_scale_for_sampling (budget - 2 ) # Exclude empty and full coalitions from budget
241241 sizes = np .arange (1 , self .n_players )
242242 samples_per_size = self .symmetric_round_even (
243243 self .get_sampling_probs (sizes ) * binom (self .n_players , sizes )
@@ -329,8 +329,11 @@ def coalitions_probability(self) -> np.ndarray:
329329 Returns:
330330 A copy of the sampled coalitions probabilities of shape ``(n_coalitions,)``
331331 """
332- return self .get_sampling_probs (self .coalitions_size )
333-
332+ probs = self .get_sampling_probs (self .coalitions_size )
333+ # Replace the empty and full coalition probabilities with 1
334+ probs [self .empty_coalition_index ] = 1.0
335+ probs [self .full_coalition_index ] = 1.0
336+ return probs
334337
335338 @property
336339 def sampling_adjustment_weights (self ) -> np .ndarray :
@@ -365,6 +368,19 @@ def empty_coalition_index(self) -> int | None:
365368 except IndexError :
366369 pass
367370 return None
371+
372+ @property
373+ def full_coalition_index (self ) -> int | None :
374+ """
375+ Returns:
376+ The index of the full coalition or ``None`` if the full coalition was not sampled.
377+ """
378+ try :
379+ if self .coalitions_per_size [- 1 ] >= 1 :
380+ return int (np .where (self .coalitions_size == self .n_players )[0 ][0 ])
381+ except IndexError :
382+ pass
383+ return None
368384
369385 def set_random_state (self , random_state : int | None ) -> None :
370386 '''
0 commit comments