@@ -1125,8 +1125,11 @@ def plot_scatter_submodels(
1125
1125
plot_kwargs : dict
1126
1126
Additional keyword arguments for the plot. Defaults to None.
1127
1127
Valid keys are:
1128
- - color_ref : matplotlib valid color for the 45 degree line
1128
+ - marker_scatter : matplotlib valid marker for the scatter plot
1129
1129
- color_scatter: matplotlib valid color for the scatter plot
1130
+ - alpha_scatter: matplotlib valid alpha for the scatter plot
1131
+ - color_ref: matplotlib valid color for the 45 degree line
1132
+ - ls_ref: matplotlib valid linestyle for the reference line
1130
1133
axes : axes
1131
1134
Matplotlib axes.
1132
1135
@@ -1140,41 +1143,71 @@ def plot_scatter_submodels(
1140
1143
submodels = np .sort (submodels )
1141
1144
1142
1145
indices = vi_results ["indices" ][submodels ]
1143
- preds = vi_results ["preds" ][submodels ]
1146
+ preds_sub = vi_results ["preds" ][submodels ]
1144
1147
preds_all = vi_results ["preds_all" ]
1145
1148
1149
+ if labels is None :
1150
+ labels = vi_results ["labels" ][submodels ]
1151
+
1152
+ # handle categorical regression case:
1153
+ n_cats = None
1154
+ if preds_all .ndim > 2 :
1155
+ n_cats = preds_all .shape [- 1 ]
1156
+ indices = np .tile (indices , n_cats )
1157
+ # labels = np.tile(labels, n_cats)
1158
+ # cats = np.repeat(np.arange(n_cats), len(indices) // n_cats)
1159
+
1146
1160
if ax is None :
1147
1161
_ , ax = _get_axes (grid , len (indices ), True , True , figsize )
1148
1162
1149
1163
if plot_kwargs is None :
1150
1164
plot_kwargs = {}
1151
1165
1152
- if labels is None :
1153
- labels = vi_results ["labels" ][submodels ]
1154
-
1155
1166
if func is not None :
1156
- preds = func (preds )
1167
+ preds_sub = func (preds_sub )
1157
1168
preds_all = func (preds_all )
1158
1169
1159
- min_ = min (np .min (preds ), np .min (preds_all ))
1160
- max_ = max (np .max (preds ), np .max (preds_all ))
1161
-
1162
- for pred , x_label , axi in zip (preds , labels , ax .ravel ()):
1163
- axi .plot (
1164
- pred ,
1165
- preds_all ,
1166
- marker = plot_kwargs .get ("marker_scatter" , "." ),
1167
- ls = "" ,
1168
- color = plot_kwargs .get ("color_scatter" , "C0" ),
1169
- alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1170
- )
1171
- axi .set_xlabel (x_label )
1172
- axi .axline (
1173
- [min_ , min_ ],
1174
- [max_ , max_ ],
1175
- color = plot_kwargs .get ("color_ref" , "0.5" ),
1176
- ls = plot_kwargs .get ("ls_ref" , "--" ),
1177
- )
1170
+ min_ = min (np .min (preds_sub ), np .min (preds_all ))
1171
+ max_ = max (np .max (preds_sub ), np .max (preds_all ))
1172
+
1173
+ # handle categorical regression case:
1174
+ if n_cats is not None :
1175
+ i = 0
1176
+ for cat in range (n_cats ):
1177
+ for pred_sub , x_label in zip (preds_sub , labels ):
1178
+ ax [i ].plot (
1179
+ pred_sub [..., cat ],
1180
+ preds_all [..., cat ],
1181
+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1182
+ ls = "" ,
1183
+ color = plot_kwargs .get ("color_scatter" , f"C{ cat } " ),
1184
+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1185
+ )
1186
+ ax [i ].set (xlabel = x_label , ylabel = "ref model" , title = f"Category { cat } " )
1187
+ ax [i ].axline (
1188
+ [min_ , min_ ],
1189
+ [max_ , max_ ],
1190
+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1191
+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1192
+ )
1193
+ i += 1
1194
+ else :
1195
+ for pred_sub , x_label , axi in zip (preds_sub , labels , ax .ravel ()):
1196
+ axi .plot (
1197
+ pred_sub ,
1198
+ preds_all ,
1199
+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1200
+ ls = "" ,
1201
+ color = plot_kwargs .get ("color_scatter" , "C0" ),
1202
+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1203
+ )
1204
+ axi .set (xlabel = x_label , ylabel = "ref model" )
1205
+ axi .axline (
1206
+ [min_ , min_ ],
1207
+ [max_ , max_ ],
1208
+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1209
+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1210
+ )
1178
1211
return ax
1179
1212
1180
1213
0 commit comments