@@ -1058,25 +1058,86 @@ def plot_variable_importance(
1058
1058
return ax
1059
1059
1060
1060
1061
- def plot_scatter_submodels (vi_results , func = None , grid = "long" , axes = None ):
1061
+ def plot_scatter_submodels (
1062
+ vi_results : dict ,
1063
+ X : npt .NDArray [np .float64 ],
1064
+ func : Optional [Callable ] = None ,
1065
+ grid : str = "long" ,
1066
+ figsize : Optional [Tuple [float , float ]] = None ,
1067
+ plot_kwargs : Optional [Dict [str , Any ]] = None ,
1068
+ axes : Optional [plt .Axes ] = None ,
1069
+ ):
1070
+ """
1071
+ Plot submodel's predictions against reference-model's predictions.
1072
+
1073
+ Parameters
1074
+ ----------
1075
+ vi_results: Dictionary
1076
+ Dictionary computed with `compute_variable_importance`
1077
+ X : npt.NDArray[np.float64]
1078
+ The covariate matrix.
1079
+ func : Optional[Callable], by default None.
1080
+ Arbitrary function to apply to the predictions. Defaults to the identity function.
1081
+ grid : str or tuple
1082
+ How to arrange the subplots. Defaults to "long", one subplot below the other.
1083
+ Other options are "wide", one subplot next to each other or a tuple indicating the number
1084
+ of rows and columns.
1085
+ plot_kwargs : dict
1086
+ Additional keyword arguments for the plot. Defaults to None.
1087
+ Valid keys are:
1088
+ - color_ref: matplotlib valid color for the 45 degree line
1089
+ - color_scatter: matplotlib valid color for the scatter plot
1090
+ axes : axes
1091
+ Matplotlib axes.
1092
+
1093
+ Returns
1094
+ -------
1095
+ axes: matplotlib axes
1096
+ """
1062
1097
indices = vi_results ["indices" ]
1063
1098
preds = vi_results ["preds" ]
1064
1099
preds_all = vi_results ["preds_all" ]
1100
+ n_vars = len (indices )
1065
1101
1066
1102
if axes is None :
1067
- _ , axes = _get_axes (grid , len (indices ), False , True , None )
1103
+ _ , axes = _get_axes (grid , len (indices ), True , True , figsize )
1104
+
1105
+ if plot_kwargs is None :
1106
+ plot_kwargs = {}
1107
+
1108
+ if hasattr (X , "columns" ) and hasattr (X , "to_numpy" ):
1109
+ labels = X .columns
1110
+
1111
+ if labels is None :
1112
+ labels = np .arange (n_vars ).astype (str )
1113
+ else :
1114
+ labels = np .asarray (labels )
1115
+
1116
+ new_labels = ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [indices ])]
1068
1117
1069
- func = None
1070
1118
if func is not None :
1071
1119
preds = func (preds )
1072
1120
preds_all = func (preds_all )
1073
1121
1074
1122
min_ = min (np .min (preds ), np .min (preds_all ))
1075
1123
max_ = max (np .max (preds ), np .max (preds_all ))
1076
1124
1077
- for pred , ax in zip (preds , axes .ravel ()):
1078
- ax .plot (pred , preds_all , "." , color = "C0" , alpha = 0.1 )
1079
- ax .axline ([min_ , min_ ], [max_ , max_ ], color = "0.5" )
1125
+ for pred , x_label , ax in zip (preds , new_labels , axes .ravel ()):
1126
+ ax .plot (
1127
+ pred ,
1128
+ preds_all ,
1129
+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1130
+ ls = "" ,
1131
+ color = plot_kwargs .get ("color_scatter" , "C0" ),
1132
+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1133
+ )
1134
+ ax .set_xlabel (x_label )
1135
+ ax .axline (
1136
+ [min_ , min_ ],
1137
+ [max_ , max_ ],
1138
+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1139
+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1140
+ )
1080
1141
1081
1142
1082
1143
def generate_sequences (n_vars , i_var , include ):
0 commit comments