|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | from typing import TYPE_CHECKING, cast |
6 | | -from typing_extensions import override |
7 | 6 |
|
8 | 7 | import numpy as np |
9 | 8 | from numpy.random import default_rng |
@@ -47,14 +46,22 @@ def __init__( |
47 | 46 | """Initializes the class. |
48 | 47 |
|
49 | 48 | Args: |
50 | | - model: The model to explain as a callable function expecting data points as input and |
| 49 | + model: The model to explain as a callable function expecting a data points as input and |
51 | 50 | returning the model's predictions. |
52 | | - data: The background data to use for the explainer as a ``np.ndarray`` of shape ``(n_samples, n_features)``. |
53 | | - x: The explanation point as a ``np.ndarray`` of shape ``(1, n_features)`` or ``(n_features,)``. Defaults to ``None``. |
54 | | - sample_size: Number of Monte Carlo samples for imputation. Defaults to ``100``. |
55 | | - random_state: The random state to use for sampling. Defaults to ``None``. |
56 | | - verbose: A flag to enable verbose imputation, which will print a progress bar for model evaluation. |
57 | | - Note that this can slow down the imputation process. Defaults to ``False``. |
| 51 | +
|
| 52 | + data: The background data to use for the explainer as a two-dimensional array with shape |
| 53 | + ``(n_samples, n_features)``. |
| 54 | +
|
| 55 | + x: The explanation point as a ``np.ndarray`` of shape ``(1, n_features)`` or |
| 56 | + ``(n_features,)``. |
| 57 | +
|
| 58 | + sample_size: The number of Monte Carlo samples to draw from the conditional background |
| 59 | + data for imputation. |
| 60 | +
|
| 61 | + random_state: An optional random seed for reproducibility. |
| 62 | +
|
| 63 | + verbose: A flag to enable verbose imputation, which will print a progress bar for model |
| 64 | + evaluation. Note that this can slow down the imputation process. |
58 | 65 |
|
59 | 66 | Raises: |
60 | 67 | CategoricalFeatureError: If the background data contains any categorical features. |
@@ -207,7 +214,6 @@ def _sample_monte_carlo( |
207 | 214 |
|
208 | 215 | return samples_all_coalitions |
209 | 216 |
|
210 | | - @override |
211 | 217 | def value_function(self, coalitions: npt.NDArray[np.bool]) -> npt.NDArray[np.floating]: |
212 | 218 | """Imputes the missing values of a data point and gets predictions for all coalitions. |
213 | 219 |
|
|
0 commit comments