@@ -1125,8 +1125,11 @@ def plot_scatter_submodels(
11251125 plot_kwargs : dict
11261126 Additional keyword arguments for the plot. Defaults to None.
11271127 Valid keys are:
1128- - color_ref : matplotlib valid color for the 45 degree line
1128+ - marker_scatter : matplotlib valid marker for the scatter plot
11291129 - color_scatter: matplotlib valid color for the scatter plot
1130+ - alpha_scatter: matplotlib valid alpha for the scatter plot
1131+ - color_ref: matplotlib valid color for the 45 degree line
1132+ - ls_ref: matplotlib valid linestyle for the reference line
11301133 axes : axes
11311134 Matplotlib axes.
11321135
@@ -1140,41 +1143,71 @@ def plot_scatter_submodels(
11401143 submodels = np .sort (submodels )
11411144
11421145 indices = vi_results ["indices" ][submodels ]
1143- preds = vi_results ["preds" ][submodels ]
1146+ preds_sub = vi_results ["preds" ][submodels ]
11441147 preds_all = vi_results ["preds_all" ]
11451148
1149+ if labels is None :
1150+ labels = vi_results ["labels" ][submodels ]
1151+
1152+ # handle categorical regression case:
1153+ n_cats = None
1154+ if preds_all .ndim > 2 :
1155+ n_cats = preds_all .shape [- 1 ]
1156+ indices = np .tile (indices , n_cats )
1157+ # labels = np.tile(labels, n_cats)
1158+ # cats = np.repeat(np.arange(n_cats), len(indices) // n_cats)
1159+
11461160 if ax is None :
11471161 _ , ax = _get_axes (grid , len (indices ), True , True , figsize )
11481162
11491163 if plot_kwargs is None :
11501164 plot_kwargs = {}
11511165
1152- if labels is None :
1153- labels = vi_results ["labels" ][submodels ]
1154-
11551166 if func is not None :
1156- preds = func (preds )
1167+ preds_sub = func (preds_sub )
11571168 preds_all = func (preds_all )
11581169
1159- min_ = min (np .min (preds ), np .min (preds_all ))
1160- max_ = max (np .max (preds ), np .max (preds_all ))
1161-
1162- for pred , x_label , axi in zip (preds , labels , ax .ravel ()):
1163- axi .plot (
1164- pred ,
1165- preds_all ,
1166- marker = plot_kwargs .get ("marker_scatter" , "." ),
1167- ls = "" ,
1168- color = plot_kwargs .get ("color_scatter" , "C0" ),
1169- alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1170- )
1171- axi .set_xlabel (x_label )
1172- axi .axline (
1173- [min_ , min_ ],
1174- [max_ , max_ ],
1175- color = plot_kwargs .get ("color_ref" , "0.5" ),
1176- ls = plot_kwargs .get ("ls_ref" , "--" ),
1177- )
1170+ min_ = min (np .min (preds_sub ), np .min (preds_all ))
1171+ max_ = max (np .max (preds_sub ), np .max (preds_all ))
1172+
1173+ # handle categorical regression case:
1174+ if n_cats is not None :
1175+ i = 0
1176+ for cat in range (n_cats ):
1177+ for pred_sub , x_label in zip (preds_sub , labels ):
1178+ ax [i ].plot (
1179+ pred_sub [..., cat ],
1180+ preds_all [..., cat ],
1181+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1182+ ls = "" ,
1183+ color = plot_kwargs .get ("color_scatter" , f"C{ cat } " ),
1184+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1185+ )
1186+ ax [i ].set (xlabel = x_label , ylabel = "ref model" , title = f"Category { cat } " )
1187+ ax [i ].axline (
1188+ [min_ , min_ ],
1189+ [max_ , max_ ],
1190+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1191+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1192+ )
1193+ i += 1
1194+ else :
1195+ for pred_sub , x_label , axi in zip (preds_sub , labels , ax .ravel ()):
1196+ axi .plot (
1197+ pred_sub ,
1198+ preds_all ,
1199+ marker = plot_kwargs .get ("marker_scatter" , "." ),
1200+ ls = "" ,
1201+ color = plot_kwargs .get ("color_scatter" , "C0" ),
1202+ alpha = plot_kwargs .get ("alpha_scatter" , 0.1 ),
1203+ )
1204+ axi .set (xlabel = x_label , ylabel = "ref model" )
1205+ axi .axline (
1206+ [min_ , min_ ],
1207+ [max_ , max_ ],
1208+ color = plot_kwargs .get ("color_ref" , "0.5" ),
1209+ ls = plot_kwargs .get ("ls_ref" , "--" ),
1210+ )
11781211 return ax
11791212
11801213
0 commit comments