@@ -108,6 +108,34 @@ def datatree_binary(seed=17):
108108 )
109109
110110
111+ def datatree_regression (seed = 17 ):
112+ """Generate a DataTree for regression data."""
113+ from scipy .stats import norm
114+
115+ rng = np .random .default_rng (seed )
116+ n_obs = 100
117+ true_sigma = 0.9
118+ true_mu = 2 * np .linspace (- 1 , 1 , n_obs )
119+ observed_data = true_mu + rng .normal (0 , true_sigma , size = n_obs )
120+
121+ posterior_sigma = rng .normal (true_sigma , 0.1 , size = (4 , 500 ))
122+ posterior_sigma = np .abs (posterior_sigma )
123+
124+ posterior_mu = rng .normal (true_mu , true_sigma * 0.5 , size = (4 , 500 , n_obs ))
125+ posterior_predictive = rng .normal (posterior_mu , true_sigma , size = (4 , 500 , n_obs ))
126+ log_likelihood = norm (posterior_mu , true_sigma ).logpdf (observed_data )
127+
128+ return from_dict (
129+ {
130+ "posterior" : {"mu" : posterior_mu , "sigma" : posterior_sigma },
131+ "posterior_predictive" : {"y" : posterior_predictive },
132+ "observed_data" : {"y" : observed_data },
133+ "log_likelihood" : {"y" : log_likelihood },
134+ },
135+ dims = {"y" : ["obs_dim" ]},
136+ )
137+
138+
111139def datatree_4d (seed = 31 ):
112140 """Generate a DataTree with a 4D posterior."""
113141 rng = np .random .default_rng (seed )
@@ -167,7 +195,7 @@ def cmp():
167195
168196
169197def fake_dt ():
170- """Generate a fake prior/posterior DataTreeZ ."""
198+ """Generate a fake prior/posterior DataTree ."""
171199 rng = np .random .default_rng (42 )
172200
173201 return from_dict (
0 commit comments