88from math import ceil
99
1010import numpy as np
11+ import pandas as pd
1112from pandas .api .types import is_bool_dtype
1213
1314from dash import html , dcc , Input , Output
@@ -1091,6 +1092,8 @@ def __init__(
10911092 n_input_cols = 4 ,
10921093 sort_features = "shap" ,
10931094 fill_row_first = True ,
1095+ feature_input_ranges = None ,
1096+ round = 2 ,
10941097 description = None ,
10951098 ** kwargs ,
10961099 ):
@@ -1116,6 +1119,9 @@ def __init__(
11161119 is 'shap' to sort by mean absolute shap value.
11171120 fill_row_first (bool, optional): if True most important features will
11181121 be on top row, if False they will be in most left column.
1122+ feature_input_ranges (dict, optional): dict mapping feature names to
1123+ (min, max) numeric ranges for input fields.
1124+ round (int, optional): number of decimals to round numeric ranges to.
11191125 description (str, optional): Tooltip to display when hover over
11201126 component title. When None default text is shown.
11211127
@@ -1131,6 +1137,8 @@ def __init__(
11311137 explainer , name = "feature-input-index-" + self .name , ** kwargs
11321138 )
11331139 self .index_name = "feature-input-index-" + self .name
1140+ self .feature_input_ranges = feature_input_ranges or {}
1141+ self .round = round
11341142
11351143 self ._feature_callback_inputs = [
11361144 Input ("feature-input-" + feature + "-input-" + self .name , "value" )
@@ -1214,17 +1222,28 @@ def _generate_dash_input(self, col, onehot_cols, onehot_dict, cat_dict):
12141222 )
12151223 else :
12161224 col_values = self .explainer .X [col ][lambda x : x != self .explainer .na_fill ]
1217- if is_bool_dtype (col_values ):
1225+ if col in self .feature_input_ranges :
1226+ min_range , max_range = self .feature_input_ranges [col ]
1227+ elif is_bool_dtype (col_values ):
12181228 min_range = int (col_values .min ())
12191229 max_range = int (col_values .max ())
12201230 else :
1221- min_range = np .round (col_values .min (), 2 )
1222- max_range = np .round (col_values .max (), 2 )
1231+ min_range = np .round (col_values .min (), self .round )
1232+ max_range = np .round (col_values .max (), self .round )
1233+
1234+ if is_bool_dtype (col_values ) or pd .api .types .is_integer_dtype (col_values ):
1235+ step = 1
1236+ else :
1237+ step = 10 ** (- self .round )
12231238 return html .Div (
12241239 [
12251240 dbc .Label (col ),
12261241 dbc .Input (
1227- id = "feature-input-" + col + "-input-" + self .name , type = "number"
1242+ id = "feature-input-" + col + "-input-" + self .name ,
1243+ type = "number" ,
1244+ min = min_range ,
1245+ max = max_range ,
1246+ step = step ,
12281247 ),
12291248 dbc .FormText (f"Range: { min_range } -{ max_range } " )
12301249 if not self .hide_range
0 commit comments