@@ -8,48 +8,82 @@ def check_lengths_same(*args):
88 raise ValueError (f"All tuple arguments must have the same length, but lengths are { tuple (map (len , args ))} ." )
99
1010
11- def check_posterior_prior_shapes ( post_variables : Tensor , prior_variables : Tensor ):
11+ def check_prior_shapes ( variables : Tensor ):
1212 """
13- Checks requirements for the shapes of posterior and prior draws as
14- necessitated by most diagnostic functions.
13+ Checks the shape of posterior draws as required by most diagnostic functions
1514
1615 Parameters
1716 ----------
18- post_samples : Tensor of shape (num_data_sets, num_post_draws, num_params)
19- The posterior draws obtained from num_data_sets
20- prior_samples : Tensor of shape (num_data_sets, num_params)
21- The prior draws obtained for generating num_data_sets
22-
23- Raises
24- ------
25- ShapeError
26- If there is a deviation form the expected shapes of `post_samples` and `prior_samples`.
17+ variables : Tensor of shape (num_data_sets, num_params)
18+ The prior_samples from generating num_data_sets
2719 """
2820
29- if len (post_variables .shape ) != 3 :
21+ if len (variables .shape ) != 2 :
3022 raise ShapeError (
31- "post_samples should be a 3-dimensional array, with the "
32- "first dimension being the number of (simulated) data sets, "
33- "the second dimension being the number of posterior draws per data set, "
34- "and the third dimension being the number of parameters (marginal distributions), "
35- f"but your input has dimensions { len (post_variables .shape )} "
23+ "prior_samples samples should be a 2-dimensional array, with the "
24+ "first dimension being the number of (simulated) data sets / prior_samples draws "
25+ "and the second dimension being the number of variables, "
26+ f"but your input has dimensions { len (variables .shape )} "
3627 )
37- elif len (prior_variables .shape ) != 2 :
28+
29+
30+ def check_estimates_shapes (variables : Tensor ):
31+ """
32+ Checks the shape of model-generated predictions (posterior draws, point estimates)
33+ as required by most diagnostic functions
34+
35+ Parameters
36+ ----------
37+ variables : Tensor of shape (num_data_sets, num_post_draws, num_params)
38+ The prior_samples from generating num_data_sets
39+ """
40+ if len (variables .shape ) != 2 and len (variables .shape ) != 3 :
3841 raise ShapeError (
39- "prior_samples should be a 2-dimensional array, with the "
40- "first dimension being the number of (simulated) data sets / prior draws "
41- "and the second dimension being the number of parameters (marginal distributions), "
42- f"but your input has dimensions { len (prior_variables .shape )} "
42+ "estimates should be a 2- or 3-dimensional array, with the "
43+ "first dimension being the number of data sets, "
44+ "(optional) second dimension the number of posterior draws per data set, "
45+ "and the last dimension the number of estimated variables, "
46+ f"but your input has dimensions { len (variables .shape )} "
4347 )
44- elif post_variables .shape [0 ] != prior_variables .shape [0 ]:
48+
49+
50+ def check_consistent_shapes (estimates : Tensor , prior_samples : Tensor ):
51+ """
52+ Checks whether the model-generated predictions (posterior draws, point estimates) and
53+ prior_samples have consistent leading (num_data_sets) and trailing (num_params) dimensions
54+ """
55+ if estimates .shape [0 ] != prior_samples .shape [0 ]:
4556 raise ShapeError (
46- "The number of elements over the first dimension of post_samples and prior_samples"
47- f"should match, but post_samples has { post_variables .shape [0 ]} and prior_samples has "
48- f"{ prior_variables .shape [0 ]} elements, respectively."
57+ "The number of elements over the first dimension of estimates and prior_samples"
58+ f"should match, but estimates have { estimates .shape [0 ]} and prior_samples has "
59+ f"{ prior_samples .shape [0 ]} elements, respectively."
4960 )
50- elif post_variables .shape [- 1 ] != prior_variables .shape [- 1 ]:
61+ if estimates .shape [- 1 ] != prior_samples .shape [- 1 ]:
5162 raise ShapeError (
52- "The number of elements over the last dimension of post_samples and prior_samples"
53- f"should match, but post_samples has { post_variables .shape [1 ]} and prior_samples has "
54- f"{ prior_variables .shape [- 1 ]} elements, respectively."
63+ "The number of elements over the last dimension of estimates and prior_samples"
64+ f"should match, but estimates has { estimates .shape [0 ]} and prior_samples has "
65+ f"{ prior_samples .shape [0 ]} elements, respectively."
5566 )
67+
68+
69+ def check_estimates_prior_shapes (estimates : Tensor , prior_samples : Tensor ):
70+ """
71+ Checks requirements for the shapes of estimates and prior_samples draws as
72+ necessitated by most diagnostic functions.
73+
74+ Parameters
75+ ----------
76+ estimates : Tensor of shape (num_data_sets, num_post_draws, num_params) or (num_data_sets, num_params)
77+ The model-generated predictions (posterior draws, point estimates) obtained from num_data_sets
78+ prior_samples : Tensor of shape (num_data_sets, num_params)
79+ The prior_samples draws obtained for generating num_data_sets
80+
81+ Raises
82+ ------
83+ ShapeError
84+ If there is a deviation form the expected shapes of `estimates` and `estimates`.
85+ """
86+
87+ check_estimates_shapes (estimates )
88+ check_prior_shapes (prior_samples )
89+ check_consistent_shapes (estimates , prior_samples )
0 commit comments