Commit b0b79f7

Merge pull request #723 from pymc-devs/csv
Add CSV backend
2 parents 2868948 + 9c084dd commit b0b79f7

7 files changed: +303 -109 lines changed

pymc3/backends/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@
 2. Text files (pymc3.backends.Text)
 3. SQLite (pymc3.backends.SQLite)
 
-The NumPy arrays and text files both hold the entire trace in memory,
-whereas SQLite commits the trace to the database while sampling.
+The NDArray backend holds the entire trace in memory, whereas the Text
+and SQLite backends store the values while sampling.
 
 Selecting a backend
 -------------------
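
As the revised docstring notes, the Text backend now writes values to disk while sampling rather than dumping an in-memory trace afterwards. A minimal usage sketch (the model, step method, and draw count are placeholders, not part of this commit):

>>> import pymc3 as pm
>>> from pymc3.backends import text
>>> with model:
...     db = pm.backends.Text('test')            # one CSV file per chain under 'test/'
...     trace = pm.sample(1000, step, trace=db)  # values written while sampling
>>> saved = text.load('test', model=model)       # MultiTrace backed by the CSV files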

pymc3/backends/base.py

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,10 @@
 from ..model import modelcontext
 
 
+class BackendError(Exception):
+    pass
+
+
 class BaseTrace(object):
     """Base trace object
 
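
`BackendError` gives the backends a shared exception type; the CSV-backed Text trace added below raises it from `setup` when an existing chain file was written for a different set of variables. A hedged sketch of catching it (the directory name and model are placeholders):

>>> from pymc3.backends import base, text
>>> with model:
...     db = text.Text('test')   # 'test/chain-0.csv' left over from another model
...     try:
...         db.setup(draws=1000, chain=0)
...     except base.BackendError:
...         print('existing CSV header does not match this model')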

pymc3/backends/sqlite.py

Lines changed: 2 additions & 0 deletions
@@ -272,6 +272,8 @@ def close(self):
         self.connected = False
 
 
+# TODO Consider merging `_create_colnames` and `_create_shape` with
+# very similar functions in the csv backend.
 def _create_colnames(shape):
     """Return column names based on `shape`.
 
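
The csv-backend helpers this TODO refers to are `_create_flat_names` and `_create_shape`, added in pymc3/backends/text.py below. A doctest-style sketch of how they round-trip a variable shape, consistent with the doctests in that module:

>>> from pymc3.backends.text import _create_flat_names, _create_shape
>>> names = _create_flat_names('y', (3, 2))
>>> names
['y__0_0', 'y__0_1', 'y__1_0', 'y__1_1', 'y__2_0', 'y__2_1']
>>> _create_shape(names)
(3, 2)
>>> _create_shape(_create_flat_names('x', ()))  # scalars carry no index suffix
()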

pymc3/backends/text.py

Lines changed: 195 additions & 101 deletions
@@ -1,42 +1,31 @@
 """Text file trace backend
 
-After sampling with NDArray backend, save results as text files.
+Store sampling values as CSV files.
 
-As this other backends, this can be used by passing the backend instance
-to `sample`.
+File format
+-----------
 
->>> import pymc3 as pm
->>> db = pm.backends.Text('test')
->>> trace = pm.sample(..., trace=db)
+Sampling values for each chain are saved in a separate file (under a
+directory specified by the `name` argument). The rows correspond to
+sampling iterations. The column names consist of variable names and
+index labels. For example, the heading
 
-Or sampling can be performed with the default NDArray backend and then
-dumped to text files after.
+  x,y__0_0,y__0_1,y__1_0,y__1_1,y__2_0,y__2_1
 
->>> from pymc3.backends import text
->>> trace = pm.sample(...)
->>> text.dump('test', trace)
-
-Database format
----------------
-
-For each chain, a directory named `chain-N` is created. In this
-directory, one file per variable is created containing the values of the
-object. To deal with multidimensional variables, the array is reshaped
-to one dimension before saving with `numpy.savetxt`. The shape and dtype
-information is saved in a json file in the same directory and is used to
-load the database back again using `numpy.loadtxt`.
+represents two variables, x and y, where x is a scalar and y has a
+shape of (3, 2).
 """
-import os
-import glob
-import json
+from glob import glob
 import numpy as np
+import os
+import pandas as pd
+import warnings
 
 from ..backends import base
-from ..backends.ndarray import NDArray
 
 
-class Text(NDArray):
-    """Text storage
+class Text(base.BaseTrace):
+    """Text trace object
 
     Parameters
     ----------
@@ -53,102 +42,207 @@ def __init__(self, name, model=None, vars=None):
             os.mkdir(name)
         super(Text, self).__init__(name, model, vars)
 
-    def close(self):
-        super(Text, self).close()
-        _dump_trace(self.name, self)
-
-
-def dump(name, trace, chains=None):
-    """Store NDArray trace as text database.
-
-    Parameters
-    ----------
-    name : str
-        Name of directory to store text files
-    trace : MultiTrace of NDArray traces
-        Result of MCMC run with default NDArray backend
-    chains : list
-        Chains to dump. If None, all chains are dumped.
-    """
-    if not os.path.exists(name):
-        os.mkdir(name)
-    if chains is None:
-        chains = trace.chains
-    for chain in chains:
-        _dump_trace(name, trace._traces[chain])
-
+        self.flat_names = {v: _create_flat_names(v, shape)
+                           for v, shape in self.var_shapes.items()}
+
+        self.filename = None
+        self._fh = None
+        self.df = None
+
+    ## Sampling methods
+
+    def setup(self, draws, chain):
+        """Perform chain-specific setup.
+
+        Parameters
+        ----------
+        draws : int
+            Expected number of draws
+        chain : int
+            Chain number
+        """
+        self.chain = chain
+        self.filename = os.path.join(self.name, 'chain-{}.csv'.format(chain))
+
+        cnames = [fv for v in self.varnames for fv in self.flat_names[v]]
+
+        if os.path.exists(self.filename):
+            with open(self.filename) as fh:
+                prev_cnames = next(fh).strip().split(',')
+            if prev_cnames != cnames:
+                raise base.BackendError(
+                    "Previous file '{}' has different variables names "
+                    "than current model.".format(self.filename))
+            self._fh = open(self.filename, 'a')
+        else:
+            self._fh = open(self.filename, 'w')
+            self._fh.write(','.join(cnames) + '\n')
+
+    def record(self, point):
+        """Record results of a sampling iteration.
+
+        Parameters
+        ----------
+        point : dict
+            Values mapped to variable names
+        """
+        vals = {}
+        for varname, value in zip(self.varnames, self.fn(point)):
+            vals[varname] = value.ravel()
+        columns = [str(val) for var in self.varnames for val in vals[var]]
+        self._fh.write(','.join(columns) + '\n')
 
-def _dump_trace(name, trace):
-    """Dump a single-chain trace.
+    def close(self):
+        self._fh.close()
+        self._fh = None  # Avoid serialization issue.
+
+    ## Selection methods
+
+    def _load_df(self):
+        if self.df is None:
+            self.df = pd.read_csv(self.filename)
+
+    def __len__(self):
+        if self.filename is None:
+            return 0
+        self._load_df()
+        return self.df.shape[0]
+
+    def get_values(self, varname, burn=0, thin=1):
+        """Get values from trace.
+
+        Parameters
+        ----------
+        varname : str
+        burn : int
+        thin : int
+
+        Returns
+        -------
+        A NumPy array
+        """
+        self._load_df()
+        var_df = self.df[self.flat_names[varname]]
+        shape = (self.df.shape[0],) + self.var_shapes[varname]
+        vals = var_df.values.ravel().reshape(shape)
+        return vals[burn::thin]
+
+    def _slice(self, idx):
+        warnings.warn('Slice for Text backend has no effect.')
+
+    def point(self, idx):
+        """Return dictionary of point values at `idx` for current chain
+        with variables names as keys.
+        """
+        idx = int(idx)
+        self._load_df()
+        pt = {}
+        for varname in self.varnames:
+            vals = self.df[self.flat_names[varname]].iloc[idx]
+            pt[varname] = vals.reshape(self.var_shapes[varname])
+        return pt
+
+
+def _create_flat_names(varname, shape):
+    """Return flat variable names for `varname` of `shape`.
+
+    Examples
+    --------
+    >>> _create_flat_names('x', (5,))
+    ['x__0', 'x__1', 'x__2', 'x__3', 'x__4']
+
+    >>> _create_flat_names('x', (2, 2))
+    ['x__0_0', 'x__0_1', 'x__1_0', 'x__1_1']
     """
-    chain_name = 'chain-{}'.format(trace.chain)
-    chain_dir = os.path.join(name, chain_name)
-    os.mkdir(chain_dir)
+    if not shape:
+        return [varname]
+    labels = (np.ravel(xs).tolist() for xs in np.indices(shape))
+    labels = (map(str, xs) for xs in labels)
+    return ['{}__{}'.format(varname, '_'.join(idxs)) for idxs in zip(*labels)]
 
-    info = {}
-    for varname in trace.varnames:
-        data = trace.get_values(varname)
 
-        if np.issubdtype(data.dtype, np.int):
-            fmt = '%i'
-            is_int = True
-        else:
-            fmt = '%g'
-            is_int = False
-        info[varname] = {'shape': data.shape, 'is_int': is_int}
+def _create_shape(flat_names):
+    "Determine shape from `_create_flat_names` output."
+    try:
+        _, shape_str = flat_names[-1].rsplit('__', 1)
+    except ValueError:
+        return ()
+    return tuple(int(i) + 1 for i in shape_str.split('_'))
 
-        var_file = os.path.join(chain_dir, varname + '.txt')
-        np.savetxt(var_file, data.reshape(-1, data.size), fmt=fmt)
-    ## Store shape and dtype information for reloading.
-    info_file = os.path.join(chain_dir, 'info.json')
-    with open(info_file, 'w') as sfh:
-        json.dump(info, sfh)
 
-
-def load(name, chains=None, model=None):
-    """Load text database.
+def load(name, model=None):
+    """Load Text database.
 
     Parameters
     ----------
     name : str
-        Path to root directory for text database
-    chains : list
-        Chains to load. If None, all chains are loaded.
+        Name of directory with files (one per chain)
    model : Model
         If None, the model is taken from the `with` context.
 
     Returns
     -------
-    ndarray.Trace instance
+    A MultiTrace instance
     """
-    chain_dirs = _get_chain_dirs(name)
-    if chains is None:
-        chains = list(chain_dirs.keys())
+    files = glob(os.path.join(name, 'chain-*.csv'))
 
     traces = []
-    for chain in chains:
-        chain_dir = chain_dirs[chain]
-        info_file = os.path.join(chain_dir, 'info.json')
-        with open(info_file, 'r') as sfh:
-            info = json.load(sfh)
-        samples = {}
-        for varname, info in info.items():
-            var_file = os.path.join(chain_dir, varname + '.txt')
-            dtype = int if info['is_int'] else float
-            flat_data = np.loadtxt(var_file, dtype=dtype)
-            samples[varname] = flat_data.reshape(info['shape'])
-        trace = NDArray(model=model)
-        trace.samples = samples
+    for f in files:
+        chain = int(os.path.splitext(f)[0].rsplit('-', 1)[1])
+        trace = Text(name, model=model)
         trace.chain = chain
+        trace.filename = f
         traces.append(trace)
     return base.MultiTrace(traces)
 
 
-def _get_chain_dirs(name):
-    """Return mapping of chain number to directory."""
-    return {_chain_dir_to_chain(chain_dir): chain_dir
-            for chain_dir in glob.glob(os.path.join(name, 'chain-*'))}
+def dump(name, trace, chains=None):
+    """Store values from NDArray trace as CSV files.
+
+    Parameters
+    ----------
+    name : str
+        Name of directory to store CSV files in
+    trace : MultiTrace of NDArray traces
+        Result of MCMC run with default NDArray backend
+    chains : list
+        Chains to dump. If None, all chains are dumped.
+    """
+    if not os.path.exists(name):
+        os.mkdir(name)
+    if chains is None:
+        chains = trace.chains
+
+    var_shapes = trace._traces[chains[0]].var_shapes
+    flat_names = {v: _create_flat_names(v, shape)
+                  for v, shape in var_shapes.items()}
+
+    for chain in chains:
+        filename = os.path.join(name, 'chain-{}.csv'.format(chain))
+        df = _trace_to_df(trace._traces[chain], flat_names)
+        df.to_csv(filename, index=False)
+
 
+def _trace_to_df(trace, flat_names=None):
+    """Convert single-chain trace to Pandas DataFrame.
 
-def _chain_dir_to_chain(chain_dir):
-    return int(os.path.basename(chain_dir).split('-')[1])
+    Parameters
+    ----------
+    trace : NDarray trace
+    flat_names : dict or None
+        A dictionary that maps each variable name in `trace` to a list
+        of flat variable names (e.g., ['x__0', 'x__1', ...])
+    """
+    if flat_names is None:
+        flat_names = {v: _create_flat_names(v, shape)
+                      for v, shape in trace.var_shapes.items()}
+
+    var_dfs = []
+    for varname, shape in trace.var_shapes.items():
+        vals = trace[varname]
+        if len(shape) == 1:
+            flat_vals = vals
+        else:
+            flat_vals = vals.reshape(len(trace), np.prod(shape))
+        var_dfs.append(pd.DataFrame(flat_vals, columns=flat_names[varname]))
+    return pd.concat(var_dfs, axis=1)
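
To tie the pieces together, a hypothetical round trip through the new `dump` and `load` functions. The model, its variables x (scalar) and y (shape (3, 2)), the step method, and the draw count are illustrative placeholders; the header comment mirrors the example in the module docstring:

>>> import pymc3 as pm
>>> from pymc3.backends import text
>>> with model:
...     trace = pm.sample(500, step)      # default NDArray backend
>>> text.dump('saved', trace)             # writes saved/chain-N.csv, one file per chain
>>> with open('saved/chain-0.csv') as fh:
...     header = next(fh).strip()         # flat column names such as
...                                       # x,y__0_0,y__0_1,y__1_0,y__1_1,y__2_0,y__2_1
>>> loaded = text.load('saved', model=model)
>>> loaded.get_values('y').shape          # (500, 3, 2): reshaped back to (draws,) + var shape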
