@@ -220,9 +220,7 @@ def plot_lm(
220220 observed_y = extract (dt , group = "observed_data" , var_names = y_obs , combined = False )
221221
222222 if isinstance (ci_prob , (list | tuple | np .ndarray )):
223- x_with_prob = x_pred .expand_dims (dim = {"prob" : ci_prob })
224- else :
225- x_with_prob = x_pred
223+ x_pred = x_pred .expand_dims (dim = {"prob" : ci_prob })
226224
227225 plot_bknd = import_module (f".backend.{ backend } " , package = "arviz_plots" )
228226 bg_color = plot_bknd .get_background_color ()
@@ -245,7 +243,7 @@ def plot_lm(
245243 pc_kwargs ["aes" ].setdefault ("color" , ["__variable__" ])
246244 pc_kwargs = set_wrap_layout (pc_kwargs , plot_bknd , x_pred )
247245 plot_collection = PlotCollection .wrap (
248- x_with_prob ,
246+ x_pred ,
249247 backend = backend ,
250248 ** pc_kwargs ,
251249 )
@@ -420,7 +418,6 @@ def plot_lm(
420418 return plot_collection
421419
422420
423- # This should be moved elsewhere
424421def combine (x_pred , pe_value , ci_data , x_vars , y_vars , smooth , smooth_kwargs = None ):
425422 """
426423 Combine and sort x_pred, pe_value, ci_data into a dataset.
@@ -436,9 +433,10 @@ def combine(x_pred, pe_value, ci_data, x_vars, y_vars, smooth, smooth_kwargs=Non
436433
437434 if smooth_kwargs is None :
438435 smooth_kwargs = {}
439- smooth_kwargs .setdefault ("window_length" , 55 )
440- smooth_kwargs .setdefault ("polyorder" , 2 )
441- smooth_kwargs .setdefault ("n_points" , 200 )
436+
437+ smooth_kwargs .setdefault ("window_length" , 55 )
438+ smooth_kwargs .setdefault ("polyorder" , 2 )
439+ smooth_kwargs .setdefault ("n_points" , 200 )
442440
443441 for xv , yv in zip (x_vars , y_vars ):
444442 old_dim = pe_value [yv ].dims [0 ]
0 commit comments