Description
I'm using alibi.explainers.PartialDependence on a dataset where some high-importance features have a significant share of missing values. The default grid point generation produces a grid of all NaNs if it encounters a NaN value in X, so I used the following workaround to generate custom grid points instead. Maybe the intention is for the user to specify the grid manually in such cases, but if it isn't, you may want to update the automatic grid generation to include NaN handling like this.
I essentially grab all of the unique non-NaN values, then interpolate as necessary to generate the grid points that I want.
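For illustration, here is a minimal NumPy-only sketch of one way a single NaN can turn a whole grid into NaNs, and how ignoring NaNs when computing the range avoids it. It assumes the grid is built from quantile-based bounds passed to `np.linspace`; I have not confirmed that this is what alibi does internally, and the toy `feature` array is made up.

```python
import numpy as np

# Toy feature column with one missing value (hypothetical data).
feature = np.array([1.0, 2.0, np.nan, 4.0, 5.0])

# Naive bounds: np.quantile propagates NaN, so both bounds are NaN
# and every point produced by np.linspace is NaN as well.
naive_lo, naive_hi = np.quantile(feature, [0.0, 1.0])
naive_grid = np.linspace(naive_lo, naive_hi, 5)

# NaN-aware bounds: np.nanquantile ignores missing values,
# so the grid spans the observed range of the feature.
safe_lo, safe_hi = np.nanquantile(feature, [0.0, 1.0])
safe_grid = np.linspace(safe_lo, safe_hi, 5)

print("naive grid:", naive_grid)
print("NaN-aware grid:", safe_grid)
```

Dropping NaNs before computing the bounds, as in the workaround, has the same effect as using the `nan`-prefixed NumPy functions here.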
I'll also point out that, analytically, it is useful for XGBoost and CatBoost to get the model's ceteris paribus prediction with the currently swept feature set to NaN as well. This would show what the model does when the feature in question is missing.
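A rough sketch of what that could look like, done by hand rather than through alibi: append NaN as an extra grid value and average the predictions for each value. The helpers `grid_with_nan` and `sweep_feature` and the `toy_predict` stand-in are hypothetical names used only for this example, and the sketch does not assume that alibi's own `grid_points` argument accepts NaN.

```python
import numpy as np

def grid_with_nan(values, grid_resolution=50):
    """Linear grid over the non-NaN range of `values`, with NaN appended."""
    values = np.asarray(values, dtype=float)
    finite = values[~np.isnan(values)]
    if finite.size == 0:
        # Nothing observed: the only meaningful point is "missing".
        return np.array([np.nan])
    grid = np.linspace(finite.min(), finite.max(), grid_resolution)
    return np.append(grid, np.nan)

def sweep_feature(predict_fn, X, feature_idx, grid):
    """Average prediction with column `feature_idx` set to each grid value."""
    X = np.asarray(X, dtype=float)
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature_idx] = value  # all other features held fixed
        averages.append(float(np.mean(predict_fn(X_mod))))
    return np.array(averages)

def toy_predict(X):
    """Stand-in for a NaN-aware model: a missing x0 contributes a fixed -1.0."""
    x0 = X[:, 0]
    return np.where(np.isnan(x0), -1.0, 2.0 * x0) + X[:, 1]

X_demo = np.array([[0.0, 1.0], [np.nan, 2.0], [4.0, 3.0]])
grid_demo = grid_with_nan(X_demo[:, 0], grid_resolution=5)
pd_values = sweep_feature(toy_predict, X_demo, 0, grid_demo)
print(pd_values)  # last entry is the average prediction with x0 missing
```

With a real XGBoost or CatBoost model, `predict_fn` would be the model's own predict method, and the last entry would show how it routes missing values for that feature.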
import numpy as np

# Create custom grid points for features with NaN values.
# X_pdp (DataFrame), feature_indices and selected_features come from my own setup.
print("Creating custom grid points for numerical features...")
custom_grids = {}
grid_resolution = 50  # Number of points we want in the grid
for idx, feat_name in zip(feature_indices, selected_features):
    feat_data = X_pdp.iloc[:, idx]
    non_nan_data = feat_data.dropna()
    if len(non_nan_data) > 0:
        unique_vals = np.unique(non_nan_data)
        # If we have enough unique values, create an interpolated grid
        if len(unique_vals) >= 10:
            # Use the 5th and 95th percentiles as bounds to avoid outliers
            min_val = np.percentile(non_nan_data, 5)
            max_val = np.percentile(non_nan_data, 95)
            grid_points = np.linspace(min_val, max_val, grid_resolution)
        else:
            # Use the unique values directly
            grid_points = unique_vals
        custom_grids[idx] = grid_points
        print(f" {feat_name}: {len(grid_points)} grid points, range [{grid_points.min():.3f}, {grid_points.max():.3f}]")
    else:
        # Feature is all NaN - use a single placeholder point (0.0)
        custom_grids[idx] = np.array([0.0])
        print(f" {feat_name}: All NaN - using single grid point [0.0]")