@@ -61,20 +61,23 @@ def compute_shapley_values(
6161 :func:`~pydvl.value.shapley.montecarlo.truncated_montecarlo_shapley`.
6262 - ``owen_sampling``: Uses the Owen continuous extension of the utility
6363 function to the unit cube. Implemented in
64- :func:`~pydvl.value.shapley.montecarlo.owen_sampling_shapley`.
65- This method requires an additional parameter `q_max` for the number of
66- subdivisions of the unit interval to use for integration.
64+ :func:`~pydvl.value.shapley.montecarlo.owen_sampling_shapley`. This
65+ method does not take a :class:`~pydvl.value.stopping.StoppingCriterion`
66+ but instead requires a parameter ``q_max`` for the number of subdivisions
67+ of the unit interval to use for integration, and another parameter
68+ ``n_samples`` for the number of subsets to sample for each $q$.
6769 - ``owen_halved``: Same as 'owen_sampling' but uses correlated samples in the
6870 expectation. Implemented in
6971 :func:`~pydvl.value.shapley.montecarlo.owen_sampling_shapley`.
7072 This method requires an additional parameter `q_max` for the number of
71- subdivisions of the interval [0,0.5] to use for integration.
73+ subdivisions of the interval [0,0.5] to use for integration, and another
74+ parameter ``n_samples`` for the number of subsets to sample for each $q$.
7275 - ``group_testing``: estimates differences of Shapley values and solves a
7376 constraint satisfaction problem. High sample complexity, not recommended.
74- Implemented in :func:`~pydvl.value.shapley.gt.group_testing_shapley`. Only
75- accepts :class:`~pydvl.value.stopping.MaxUpdates` (use
76- :func:`~pydvl.value.shapley.gt.num_samples_eps_delta` to compute a bound)
77- and :class:`~pydvl.value.stopping.MaxTime` as stopping criteria .
77+ Implemented in :func:`~pydvl.value.shapley.gt.group_testing_shapley`. This
78+ method does not take a :class:`~pydvl.value.stopping.StoppingCriterion`
79+ but instead requires a parameter ``n_samples`` for the number of
80+ iterations to run .
7881
7982 Additionally, one can use model-specific methods:
8083
@@ -126,8 +129,8 @@ def compute_shapley_values(
126129 elif mode == ShapleyMode .PermutationExact :
127130 return permutation_exact_shapley (u , progress = progress )
128131 elif mode == ShapleyMode .Owen or mode == ShapleyMode .OwenAntithetic :
129- if kwargs .get ("n_iterations " ) is None :
130- raise ValueError ("n_iterations cannot be None for Owen methods" )
132+ if kwargs .get ("n_samples " ) is None :
133+ raise ValueError ("n_samples cannot be None for Owen methods" )
131134 if kwargs .get ("max_q" ) is None :
132135 raise ValueError ("Owen Sampling requires max_q for the outer integral" )
133136
@@ -138,17 +141,17 @@ def compute_shapley_values(
138141 )
139142 return owen_sampling_shapley (
140143 u ,
141- n_iterations = int (kwargs .get ("n_iterations " , - 1 )),
144+ n_samples = int (kwargs .get ("n_samples " , - 1 )),
142145 max_q = int (kwargs .get ("max_q" , - 1 )),
143146 method = method ,
144147 n_jobs = n_jobs ,
145148 )
146149 elif mode == ShapleyMode .KNN :
147150 return knn_shapley (u , progress = progress )
148151 elif mode == ShapleyMode .GroupTesting :
149- n_iterations = kwargs .pop ("n_iterations " )
150- if n_iterations is None :
151- raise ValueError ("n_iterations cannot be None for Group Testing" )
152+ n_samples = kwargs .pop ("n_samples " )
153+ if n_samples is None :
154+ raise ValueError ("n_samples cannot be None for Group Testing" )
152155 epsilon = kwargs .pop ("epsilon" )
153156 if epsilon is None :
154157 raise ValueError ("Group Testing requires error bound epsilon" )
@@ -157,7 +160,7 @@ def compute_shapley_values(
157160 u ,
158161 epsilon = epsilon ,
159162 delta = delta ,
160- n_iterations = n_iterations ,
163+ n_samples = n_samples ,
161164 n_jobs = n_jobs ,
162165 progress = progress ,
163166 ** kwargs ,
0 commit comments