1
1
"""Text file trace backend
2
2
3
- After sampling with NDArray backend, save results as text files.
3
+ Store sampling values as CSV files.
4
4
5
- As this other backends, this can be used by passing the backend instance
6
- to `sample`.
5
+ File format
6
+ -----------
7
7
8
- >>> import pymc3 as pm
9
- >>> db = pm.backends.Text('test')
10
- >>> trace = pm.sample(..., trace=db)
8
+ Sampling values for each chain are saved in a separate file (under a
9
+ directory specified by the `name` argument). The rows correspond to
10
+ sampling iterations. The column names consist of variable names and
11
+ index labels. For example, the heading
11
12
12
- Or sampling can be performed with the default NDArray backend and then
13
- dumped to text files after.
13
+ x,y__0_0,y__0_1,y__1_0,y__1_1,y__2_0,y__2_1
14
14
15
- >>> from pymc3.backends import text
16
- >>> trace = pm.sample(...)
17
- >>> text.dump('test', trace)
18
-
19
- Database format
20
- ---------------
21
-
22
- For each chain, a directory named `chain-N` is created. In this
23
- directory, one file per variable is created containing the values of the
24
- object. To deal with multidimensional variables, the array is reshaped
25
- to one dimension before saving with `numpy.savetxt`. The shape and dtype
26
- information is saved in a json file in the same directory and is used to
27
- load the database back again using `numpy.loadtxt`.
15
+ represents two variables, x and y, where x is a scalar and y has a
16
+ shape of (3, 2).
28
17
"""
29
- import os
30
- import glob
31
- import json
18
+ from glob import glob
32
19
import numpy as np
20
+ import os
21
+ import pandas as pd
22
+ import warnings
33
23
34
24
from ..backends import base
35
- from ..backends .ndarray import NDArray
36
25
37
26
38
- class Text (NDArray ):
39
- """Text storage
27
+ class Text (base . BaseTrace ):
28
+ """Text trace object
40
29
41
30
Parameters
42
31
----------
@@ -53,102 +42,207 @@ def __init__(self, name, model=None, vars=None):
53
42
os .mkdir (name )
54
43
super (Text , self ).__init__ (name , model , vars )
55
44
56
- def close (self ):
57
- super (Text , self ).close ()
58
- _dump_trace (self .name , self )
59
-
60
-
61
- def dump (name , trace , chains = None ):
62
- """Store NDArray trace as text database.
63
-
64
- Parameters
65
- ----------
66
- name : str
67
- Name of directory to store text files
68
- trace : MultiTrace of NDArray traces
69
- Result of MCMC run with default NDArray backend
70
- chains : list
71
- Chains to dump. If None, all chains are dumped.
72
- """
73
- if not os .path .exists (name ):
74
- os .mkdir (name )
75
- if chains is None :
76
- chains = trace .chains
77
- for chain in chains :
78
- _dump_trace (name , trace ._traces [chain ])
79
-
45
+ self .flat_names = {v : _create_flat_names (v , shape )
46
+ for v , shape in self .var_shapes .items ()}
47
+
48
+ self .filename = None
49
+ self ._fh = None
50
+ self .df = None
51
+
52
+ ## Sampling methods
53
+
54
+ def setup (self , draws , chain ):
55
+ """Perform chain-specific setup.
56
+
57
+ Parameters
58
+ ----------
59
+ draws : int
60
+ Expected number of draws
61
+ chain : int
62
+ Chain number
63
+ """
64
+ self .chain = chain
65
+ self .filename = os .path .join (self .name , 'chain-{}.csv' .format (chain ))
66
+
67
+ cnames = [fv for v in self .varnames for fv in self .flat_names [v ]]
68
+
69
+ if os .path .exists (self .filename ):
70
+ with open (self .filename ) as fh :
71
+ prev_cnames = next (fh ).strip ().split (',' )
72
+ if prev_cnames != cnames :
73
+ raise base .BackendError (
74
+ "Previous file '{}' has different variables names "
75
+ "than current model." .format (self .filename ))
76
+ self ._fh = open (self .filename , 'a' )
77
+ else :
78
+ self ._fh = open (self .filename , 'w' )
79
+ self ._fh .write (',' .join (cnames ) + '\n ' )
80
+
81
def record(self, point):
    """Append one sampling iteration to the open chain file.

    Parameters
    ----------
    point : dict
        Values mapped to variable names

    Each value returned by `self.fn(point)` is flattened and written,
    in `self.varnames` order, as one comma-separated row.
    """
    flattened = dict(zip(self.varnames,
                         (value.ravel() for value in self.fn(point))))
    cells = []
    for name in self.varnames:
        cells.extend(str(item) for item in flattened[name])
    row = ','.join(cells)
    self._fh.write(row + '\n')
80
94
81
- def _dump_trace (name , trace ):
82
- """Dump a single-chain trace.
95
def close(self):
    """Close the chain's file handle, if one is open.

    Safe to call more than once or before `setup`: when no handle is
    open the call does nothing.
    """
    if self._fh is None:
        return
    self._fh.close()
    self._fh = None  # Avoid serialization issue.
98
+
99
+ ## Selection methods
100
+
101
+ def _load_df (self ):
102
+ if self .df is None :
103
+ self .df = pd .read_csv (self .filename )
104
+
105
+ def __len__ (self ):
106
+ if self .filename is None :
107
+ return 0
108
+ self ._load_df ()
109
+ return self .df .shape [0 ]
110
+
111
def get_values(self, varname, burn=0, thin=1):
    """Get values from trace.

    Parameters
    ----------
    varname : str
    burn : int
        Number of leading rows to drop
    thin : int
        Step between the rows that are kept

    Returns
    -------
    A NumPy array whose first axis indexes the kept rows and whose
    remaining axes follow `self.var_shapes[varname]`.
    """
    self._load_df()
    n_rows = self.df.shape[0]
    columns = self.flat_names[varname]
    full_shape = (n_rows,) + self.var_shapes[varname]
    values = self.df[columns].values.reshape(full_shape)
    return values[burn::thin]
129
+
130
+ def _slice (self , idx ):
131
+ warnings .warn ('Slice for Text backend has no effect.' )
132
+
133
def point(self, idx):
    """Return dictionary of point values at `idx` for current chain
    with variables names as keys.

    Each value is a NumPy array reshaped to the variable's entry in
    `self.var_shapes`.
    """
    idx = int(idx)
    self._load_df()
    pt = {}
    for varname in self.varnames:
        row = self.df[self.flat_names[varname]].iloc[idx]
        # `.iloc[idx]` yields a pandas Series, which has no `reshape`
        # in current pandas; convert to a NumPy array first.
        pt[varname] = np.asarray(row).reshape(self.var_shapes[varname])
    return pt
144
+
145
+
146
+ def _create_flat_names (varname , shape ):
147
+ """Return flat variable names for `varname` of `shape`.
148
+
149
+ Examples
150
+ --------
151
+ >>> _create_flat_names('x', (5,))
152
+ ['x__0', 'x__1', 'x__2', 'x__3', 'x__4']
153
+
154
+ >>> _create_flat_names('x', (2, 2))
155
+ ['x__0_0', 'x__0_1', 'x__1_0', 'x__1_1']
83
156
"""
84
- chain_name = 'chain-{}' .format (trace .chain )
85
- chain_dir = os .path .join (name , chain_name )
86
- os .mkdir (chain_dir )
157
+ if not shape :
158
+ return [varname ]
159
+ labels = (np .ravel (xs ).tolist () for xs in np .indices (shape ))
160
+ labels = (map (str , xs ) for xs in labels )
161
+ return ['{}__{}' .format (varname , '_' .join (idxs )) for idxs in zip (* labels )]
87
162
88
- info = {}
89
- for varname in trace .varnames :
90
- data = trace .get_values (varname )
91
163
92
- if np . issubdtype ( data . dtype , np . int ):
93
- fmt = '%i'
94
- is_int = True
95
- else :
96
- fmt = '%g'
97
- is_int = False
98
- info [ varname ] = { 'shape' : data . shape , 'is_int' : is_int }
164
+ def _create_shape ( flat_names ):
165
+ "Determine shape from `_create_flat_names` output."
166
+ try :
167
+ _ , shape_str = flat_names [ - 1 ]. rsplit ( '__' , 1 )
168
+ except ValueError :
169
+ return ()
170
+ return tuple ( int ( i ) + 1 for i in shape_str . split ( '_' ))
99
171
100
- var_file = os .path .join (chain_dir , varname + '.txt' )
101
- np .savetxt (var_file , data .reshape (- 1 , data .size ), fmt = fmt )
102
- ## Store shape and dtype information for reloading.
103
- info_file = os .path .join (chain_dir , 'info.json' )
104
- with open (info_file , 'w' ) as sfh :
105
- json .dump (info , sfh )
106
172
107
-
108
- def load (name , chains = None , model = None ):
109
- """Load text database.
173
def load(name, model=None):
    """Load Text database.

    Every ``chain-*.csv`` file found in `name` becomes one `Text` trace
    whose chain number is taken from the file name.

    Parameters
    ----------
    name : str
        Name of directory with files (one per chain)
    model : Model
        If None, the model is taken from the `with` context.

    Returns
    -------
    A MultiTrace instance
    """
    pattern = os.path.join(name, 'chain-*.csv')

    traces = []
    for path in glob(pattern):
        stem = os.path.splitext(path)[0]
        chain = int(stem.rsplit('-', 1)[1])
        chain_trace = Text(name, model=model)
        chain_trace.chain = chain
        chain_trace.filename = path
        traces.append(chain_trace)
    return base.MultiTrace(traces)
145
197
146
198
147
- def _get_chain_dirs (name ):
148
- """Return mapping of chain number to directory."""
149
- return {_chain_dir_to_chain (chain_dir ): chain_dir
150
- for chain_dir in glob .glob (os .path .join (name , 'chain-*' ))}
199
def dump(name, trace, chains=None):
    """Store values from NDArray trace as CSV files.

    One ``chain-N.csv`` file is written per chain. Column names are
    derived from the variable shapes of the first chain in `chains`.

    Parameters
    ----------
    name : str
        Name of directory to store CSV files in
    trace : MultiTrace of NDArray traces
        Result of MCMC run with default NDArray backend
    chains : list
        Chains to dump. If None, all chains are dumped.
    """
    if not os.path.exists(name):
        os.mkdir(name)
    if chains is None:
        chains = trace.chains

    reference = trace._traces[chains[0]]
    flat_names = {}
    for varname, shape in reference.var_shapes.items():
        flat_names[varname] = _create_flat_names(varname, shape)

    for chain in chains:
        target = os.path.join(name, 'chain-{}.csv'.format(chain))
        frame = _trace_to_df(trace._traces[chain], flat_names)
        frame.to_csv(target, index=False)
224
+
151
225
226
+ def _trace_to_df (trace , flat_names = None ):
227
+ """Convert single-chain trace to Pandas DataFrame.
152
228
153
- def _chain_dir_to_chain (chain_dir ):
154
- return int (os .path .basename (chain_dir ).split ('-' )[1 ])
229
+ Parameters
230
+ ----------
231
+ trace : NDarray trace
232
+ flat_names : dict or None
233
+ A dictionary that maps each variable name in `trace` to a list
234
+ of flat variable names (e.g., ['x__0', 'x__1', ...])
235
+ """
236
+ if flat_names is None :
237
+ flat_names = {v : _create_flat_names (v , shape )
238
+ for v , shape in trace .var_shapes .items ()}
239
+
240
+ var_dfs = []
241
+ for varname , shape in trace .var_shapes .items ():
242
+ vals = trace [varname ]
243
+ if len (shape ) == 1 :
244
+ flat_vals = vals
245
+ else :
246
+ flat_vals = vals .reshape (len (trace ), np .prod (shape ))
247
+ var_dfs .append (pd .DataFrame (flat_vals , columns = flat_names [varname ]))
248
+ return pd .concat (var_dfs , axis = 1 )
0 commit comments