-
-
Notifications
You must be signed in to change notification settings - Fork 12
Refactor plot_lm #343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Refactor plot_lm #343
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #343 +/- ##
==========================================
- Coverage 85.94% 85.29% -0.66%
==========================================
Files 48 48
Lines 6014 6060 +46
==========================================
Hits 5169 5169
- Misses 845 891 +46 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
This looks awesome! So this is doing almost the same as the old |
|
It is doing similar things. Maybe we don't need to have a separate function like |
|
Agree! This looks great! Thank you! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like most of the changes but we need more examples both to showcase that in the docs and for user testing.
| * pe_line-> passed to :func:`~.visuals.line_xy`. | ||
| * ci_band -> passed to :func:`~.visuals.fill_between_y`. | ||
| * ci_bounds -> passed to :func:`~.visuals.line_xy`. Defaults to False | ||
| * ci_line_y -> passed to :func:`~.visuals.ci_line_y`. Defaults to False | ||
| * observed_scatter -> passed to :func:`~.visuals.scatter_xy`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the new names because they are more descriptive, but I think it is still a bit confusing. I would add a short description of each visual element along with the function arguments are passed to
src/arviz_plots/plots/lm_plot.py
Outdated
|
|
||
| for xv, yv in zip(x_vars, y_vars): | ||
| old_dim = pe_value[yv].dims[0] | ||
| y_aligned = pe_value[yv].rename({old_dim: "dim_0"}).reindex(dim_0=x_pred.dim_0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dim_0 has a probability of name clash too high to use it here and without fallback. If we keep this I think we should use a different name.
However, I am not sure we should keep this. I think the goal of this from the example is to have μ_t and μ_h with different dimension names work in here but I think those dimensions should have been named the same from the beginning and match the dimension name for the variables in observed_data and in constant_data
src/arviz_plots/plots/lm_plot.py
Outdated
| values = np.stack([x_pred[xv].values, y_aligned.values, lower, upper], axis=0) | ||
| order = np.argsort(values[0]) | ||
| values_sorted = values[:, order] | ||
| x_sorted = values_sorted[0] | ||
|
|
||
| if smooth: | ||
| x_grid = np.linspace(x_sorted.min(), x_sorted.max(), n_points) | ||
| x_grid[0] = (x_grid[0] + x_grid[1]) / 2 | ||
| values_smoothed = np.zeros((4, n_points)) | ||
| values_smoothed[0] = x_grid | ||
| for i in range(1, 4): | ||
| y_interp = griddata(x_sorted, values_sorted[i], x_grid) | ||
| values_smoothed[i] = savgol_filter(y_interp, axis=0, **smooth_kwargs) | ||
| values_sorted = values_smoothed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have an issue to move all this to arviz-stats?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's easier to discuss what we want here, and once we are happy, move this to arviz-stats. We have an open issue about smoothing HDIs; this could solve it. I think is fine to have this functionality just for plot_lm, not sure we need to expose it as an independent feature.
src/arviz_plots/plots/lm_plot.py
Outdated
| if smooth: | ||
| new_dim = xr.IndexVariable("dim_0", np.arange(n_points)) | ||
| else: | ||
| new_dim = x_pred.dim_0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not do this as I think it will not be correct in a lot of cases. Data is always sorted so the plot looks good and makes sense, so we'd have to take that from x_sorted. However, when there are multiple variables, x_sorted will have a different order for each potentially which would mean the final creation of the dataset of of combined_data would undo the sorting for all variables but one.
I see this as the main argument to change the dimension name to something else. However, sort which does this kind of sorting, removes the coordinate values from the dimension instead. The dimension is conceptually the same, but the coordinate values might no longer make sense and they might also not be aligned between variables/dimensions (if we added support for multiple dimensions in addition to this dim_0 which could be then be facetted on or get aesthetic mappings)
src/arviz_plots/plots/lm_plot.py
Outdated
| # This is intended for categorical x values or few unique values of x | ||
| # where fill_between_y is not appropriate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is what I would add to the description of the visuals I mentioned in another comment. I think moving is fine, but copying also ok with me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will move this up. This was also a reminder to discuss that it would be nice to be able to automatically, or by having a "var_discrete" argument, to have a mix of continuous and discrete/categorical.
src/arviz_plots/plots/lm_plot.py
Outdated
| if smooth: | ||
| new_dim = xr.IndexVariable("dim_0", np.arange(n_points)) | ||
| else: | ||
| new_dim = x_pred.dim_0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not do this as I think it will not be correct in a lot of cases. Data is always sorted so the plot looks good and makes sense, so we'd have to take that from x_sorted. However, when there are multiple variables, x_sorted will have a different order for each potentially which would mean the final creation of the dataset of of combined_data would undo the sorting for all variables but one.
I see this as the main argument to change the dimension name to something else. However, sort which does this kind of sorting, removes the coordinate values from the dimension instead. The dimension is conceptually the same, but the coordinate values might no longer make sense and they might also not be aligned between variables/dimensions (if we added support for multiple dimensions in addition to this dim_0 which could be then be facetted on or get aesthetic mappings)
|
I have addressed some of your comments @OriolAbril, the easy ones. Regarding what to do with the |
| If None (default), and if group is "predictions", all variables corresponding to x data | ||
| in "predictions_constant_data" group are used. If group is "posterior_predictive", | ||
| x is used. | ||
| x : str optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| x : str optional | |
| x : str, optional |
| x is used. | ||
| x : str optional | ||
| Independent variable. If None, use the first variable in constant_data group. | ||
| y : str optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| y : str optional | |
| y : str, optional |
| Independent variable. If None, use the first variable in constant_data group. | ||
| y : str optional | ||
| Response variable or linear term. If None, use the first variable in observed_data group. | ||
| y_obs : str or DataArray, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to explain the difference between y and y_obs a bit more. They're looked up in different groups, right? I would mention it.
Also, why does it make different assumptions about y and y_obs? I can imagine it's related to variables available in the different groups, but I'm not sure.
| If None (default), interpret var_names as the real variables names. | ||
| If “like”, interpret var_names as substrings of the real variables names. | ||
| If “regex”, interpret var_names as regular expressions on the real variables names. | ||
| It is used for any of y, x, y_pred, and x_pred if they are strings or lists of strings. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are y_pred and x_pred still relevant to the user?
| * central_line -> passed to :func:`~.visuals.line_xy`. | ||
| * ci_fill -> passed to :func:`~.visuals.fill_between_y`. | ||
| * scatter -> passed to :func:`~.visuals.scatter_xy`. | ||
| * pe_line-> passed to :func:`~.visuals.line_xy`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are spaces between lines and around -> meaningful for doc building? Asking to make sure things are lay out as expected.
| y = process_group_variables_coords( | ||
| dt, group="observed_data", var_names=y, filter_vars=filter_vars, coords=coords | ||
| ) | ||
| y = list(obs_data.data_vars)[:1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens with x and y if either observed_data or constant_data do not exist? Would get_group work? Maybe it makes sense to first check availability of the groups, then availability of data vars within the groups, and then raise appropriate errors if needed?
| ) | ||
| if isinstance(ci_prob, Sequence): | ||
| aes_by_visuals.setdefault("ci_line", {"alpha"}) | ||
| if isinstance(ci_prob, (list | tuple | np.ndarray)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see this check in several places. As far as I understand, you use it to act different in the presence of multiple intervals instead of one.
Would it make sense to convert ci_prob to a one-dimensional numpy array early on, and then always act as if they are many? Perhaps something custom can be done if len(arr) == 1.
src/arviz_plots/plots/lm_plot.py
Outdated
| pe_value = azs.mode(y_pred, dim=central_line_dims, **stats.get("point_estimate", {})) | ||
|
|
||
| lines = plot_bknd.get_default_aes("linestyle", 2, {}) | ||
| ds_combined = combine(x_pred, pe_value, ci_data, x, y, smooth, stats.get("smooth", {})) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring mentions the valid keys for stats are credible_interval and point_estimate... something needs an update?
| Combine and sort x_pred, pe_value, ci_data into a dataset. | ||
| The resulting dataset will have a dimension plot_axis=['x','y','y_bottom','y_top'], | ||
| and will sort each variable by its x values, and optionally smooth along dim_0. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think dim_0 is wrong?
|
I did some experiments, and I like the new API, I am trying to use xarray more to simplify a bit the smoothing and data wrangling. My main doubt (whose answer might be use a different plot) was about trying to use so we can compare the bands that come from the posterior predictive with the observations and see if we can identify specific regions in temperature or humidity where the model doesn't work well? |
I started working on adding a smooth option to
plot_lmas presented in #317. But after playing a little bit with the function, I decided to incorporate other changes, which may be a better fit given the design and expectations of other plots in ArviZ. But given thatplot_lmhas always been a weird plot for us, and I may be too focused on what I want to do, and then miss other uses. So, to make things more concrete, I will show the steps of the two main use cases I have in mind. Notice thecombinefunction I am adding in this PR; I may be overcomplicating things there, or not making the correct assumptions (e.g., it fails when ci_prob is a list instead of a float).A common pattern in ArviZ-plots is to create a DataTree first, plot second. So I follow that pattern here. For instance, the arguments "x" and "y" are equivalent to "var_names" in other plots.
For the most straightforward example, we have a linear regression with only one predictor. Here I am using the bikes dataset and a NegativeBinomial family, the model is essentially.
rented ~ temperatureOnce the model is sampled, we do
And then we can plot the predictions with
Multiple bands
Notice the y-axis labels are not correct yet.
More often than not. We have regression models with more than one covariate. For those cases, what we want to plot is the marginal predictions (or marginals eta/linear-terms).
rented ~ temperature + humidityFor such models, we need to compute the marginals somehow. PyMC-BART has custom functions for this that take advantage of the tree structure to efficiently compute predictions. Bambi has the interpret module for the slope/predictions. I mention those for context, but also because, at least in the PyMC-BART case, it would be nice to replace the plotting part with arviz-plots.
Assuming we want to compute the marginals manually, we could do something like this
After that, we can plot the marginal values i.e., how the mean of the rented bikes changes with temperature when we keep the humidity at its mean value. and the other way around
We may want to plot the posterior means, instead of predictions
📚 Documentation preview 📚: https://arviz-plots--343.org.readthedocs.build/en/343/