22
22
23
23
from abc import ABC
24
24
from typing import (
25
+ Any ,
25
26
Dict ,
26
27
List ,
28
+ Mapping ,
27
29
Optional ,
28
30
Sequence ,
29
31
Set ,
@@ -47,7 +49,87 @@ class BackendError(Exception):
47
49
pass
48
50
49
51
50
- class BaseTrace (ABC ):
52
+ class IBaseTrace (ABC , Sized ):
53
+ """Minimal interface needed to record and access draws and stats for one MCMC chain."""
54
+
55
+ chain : int
56
+ """Chain number."""
57
+
58
+ varnames : List [str ]
59
+ """Names of tracked variables."""
60
+
61
+ sampler_vars : List [Dict [str , type ]]
62
+ """Sampler stats for each sampler."""
63
+
64
+ def __len__ (self ):
65
+ raise NotImplementedError ()
66
+
67
+ def get_values (self , varname : str , burn = 0 , thin = 1 ) -> np .ndarray :
68
+ """Get values from trace.
69
+
70
+ Parameters
71
+ ----------
72
+ varname: str
73
+ burn: int
74
+ thin: int
75
+
76
+ Returns
77
+ -------
78
+ A NumPy array
79
+ """
80
+ raise NotImplementedError ()
81
+
82
+ def get_sampler_stats (self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1 ):
83
+ """Get sampler statistics from the trace.
84
+
85
+ Parameters
86
+ ----------
87
+ stat_name: str
88
+ sampler_idx: int or None
89
+ burn: int
90
+ thin: int
91
+
92
+ Returns
93
+ -------
94
+ If the `sampler_idx` is specified, return the statistic with
95
+ the given name in a numpy array. If it is not specified and there
96
+ is more than one sampler that provides this statistic, return
97
+ a numpy array of shape (m, n), where `m` is the number of
98
+ such samplers, and `n` is the number of samples.
99
+ """
100
+ raise NotImplementedError ()
101
+
102
+ def _slice (self , idx : slice ) -> "IBaseTrace" :
103
+ """Slice trace object."""
104
+ raise NotImplementedError ()
105
+
106
+ def point (self , idx : int ) -> Dict [str , np .ndarray ]:
107
+ """Return dictionary of point values at `idx` for current chain
108
+ with variables names as keys.
109
+ """
110
+ raise NotImplementedError ()
111
+
112
+ def record (self , draw : Mapping [str , np .ndarray ], stats : Sequence [Mapping [str , Any ]]):
113
+ """Record results of a sampling iteration.
114
+
115
+ Parameters
116
+ ----------
117
+ draw: dict
118
+ Values mapped to variable names
119
+ stats: list of dicts
120
+ The diagnostic values for each sampler
121
+ """
122
+ raise NotImplementedError ()
123
+
124
+ def close (self ):
125
+ """Close the backend.
126
+
127
+ This is called after sampling has finished.
128
+ """
129
+ pass
130
+
131
+
132
+ class BaseTrace (IBaseTrace ):
51
133
"""Base trace object
52
134
53
135
Parameters
@@ -127,25 +209,6 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
127
209
self ._set_sampler_vars (sampler_vars )
128
210
self ._is_base_setup = True
129
211
130
- def record (self , point , sampler_states = None ):
131
- """Record results of a sampling iteration.
132
-
133
- Parameters
134
- ----------
135
- point: dict
136
- Values mapped to variable names
137
- sampler_states: list of dicts
138
- The diagnostic values for each sampler
139
- """
140
- raise NotImplementedError
141
-
142
- def close (self ):
143
- """Close the database backend.
144
-
145
- This is called after sampling has finished.
146
- """
147
- pass
148
-
149
212
# Selection methods
150
213
151
214
def __getitem__ (self , idx ):
@@ -157,24 +220,6 @@ def __getitem__(self, idx):
157
220
except (ValueError , TypeError ): # Passed variable or variable name.
158
221
raise ValueError ("Can only index with slice or integer" )
159
222
160
- def __len__ (self ):
161
- raise NotImplementedError
162
-
163
- def get_values (self , varname , burn = 0 , thin = 1 ):
164
- """Get values from trace.
165
-
166
- Parameters
167
- ----------
168
- varname: str
169
- burn: int
170
- thin: int
171
-
172
- Returns
173
- -------
174
- A NumPy array
175
- """
176
- raise NotImplementedError
177
-
178
223
def get_sampler_stats (self , stat_name , sampler_idx = None , burn = 0 , thin = 1 ):
179
224
"""Get sampler statistics from the trace.
180
225
@@ -220,19 +265,9 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
220
265
"""Get sampler statistics."""
221
266
raise NotImplementedError ()
222
267
223
- def _slice (self , idx : Union [int , slice ]):
224
- """Slice trace object."""
225
- raise NotImplementedError ()
226
-
227
- def point (self , idx : int ) -> Dict [str , np .ndarray ]:
228
- """Return dictionary of point values at `idx` for current chain
229
- with variables names as keys.
230
- """
231
- raise NotImplementedError ()
232
-
233
268
@property
234
269
def stat_names (self ) -> Set [str ]:
235
- names = set ()
270
+ names : Set [ str ] = set ()
236
271
for vars in self .sampler_vars or []:
237
272
names .update (vars .keys ())
238
273
@@ -290,7 +325,7 @@ class MultiTrace:
290
325
List of variable names in the trace(s)
291
326
"""
292
327
293
- def __init__ (self , straces : Sequence [BaseTrace ]):
328
+ def __init__ (self , straces : Sequence [IBaseTrace ]):
294
329
if len ({t .chain for t in straces }) != len (straces ):
295
330
raise ValueError ("Chains are not unique." )
296
331
self ._straces = {t .chain : t for t in straces }
@@ -386,7 +421,7 @@ def stat_names(self) -> Set[str]:
386
421
sampler_vars = [s .sampler_vars for s in self ._straces .values ()]
387
422
if not all (svars == sampler_vars [0 ] for svars in sampler_vars ):
388
423
raise ValueError ("Inividual chains contain different sampler stats" )
389
- names = set ()
424
+ names : Set [ str ] = set ()
390
425
for trace in self ._straces .values ():
391
426
if trace .sampler_vars is None :
392
427
continue
@@ -472,7 +507,7 @@ def get_sampler_stats(
472
507
]
473
508
return _squeeze_cat (results , combine , squeeze )
474
509
475
- def _slice (self , slice ):
510
+ def _slice (self , slice : slice ):
476
511
"""Return a new MultiTrace object sliced according to `slice`."""
477
512
new_traces = [trace ._slice (slice ) for trace in self ._straces .values ()]
478
513
trace = MultiTrace (new_traces )
0 commit comments