@@ -532,24 +532,27 @@ def to_netcdf(
532532 return filename
533533
534534 def to_datatree (self ):
535- """Convert InferenceData object to a :class:`~datatree .DataTree`."""
535+ """Convert InferenceData object to a :class:`~xarray .DataTree`."""
536536 try :
537- from datatree import DataTree
538- except ModuleNotFoundError as err :
539- raise ModuleNotFoundError (
540- "datatree must be installed in order to use InferenceData.to_datatree"
537+ from xarray import DataTree
538+ except ImportError as err :
539+ raise ImportError (
540+ "xarray must be have DataTree in order to use InferenceData.to_datatree. "
541+ "Update to xarray>=2024.11.0"
541542 ) from err
542543 return DataTree .from_dict ({group : ds for group , ds in self .items ()})
543544
544545 @staticmethod
545546 def from_datatree (datatree ):
546- """Create an InferenceData object from a :class:`~datatree .DataTree`.
547+ """Create an InferenceData object from a :class:`~xarray .DataTree`.
547548
548549 Parameters
549550 ----------
550551 datatree : DataTree
551552 """
552- return InferenceData (** {group : sub_dt .to_dataset () for group , sub_dt in datatree .items ()})
553+ return InferenceData (
554+ ** {group : child .to_dataset () for group , child in datatree .children .items ()}
555+ )
553556
554557 def to_dict (self , groups = None , filter_groups = None ):
555558 """Convert InferenceData to a dictionary following xarray naming conventions.
@@ -1531,9 +1534,8 @@ def add_groups(
15311534 import xarray as xr
15321535 from xarray_einstats.stats import XrDiscreteRV
15331536 from scipy.stats import poisson
1534- dist = XrDiscreteRV(poisson)
1535- log_lik = xr.Dataset()
1536- log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
1537+ dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
1538+ log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
15371539 idata2.add_groups({"log_likelihood": log_lik})
15381540 idata2
15391541
0 commit comments