Skip to content

Commit 2394fac

Browse files
aloctavodiaspringcoil
authored andcommitted
add varnames argument for consintency (#1641)
* add varnames argument for consintency * do not pass flat_names argument (or varnames)
1 parent 19fcf4b commit 2394fac

File tree

2 files changed

+21
-17
lines changed

2 files changed

+21
-17
lines changed

pymc3/backends/text.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,8 @@ def dump(name, trace, chains=None):
192192
if chains is None:
193193
chains = trace.chains
194194

195-
var_shapes = trace._straces[chains[0]].var_shapes
196-
flat_names = {v: ttab.create_flat_names(v, shape)
197-
for v, shape in var_shapes.items()}
198-
199195
for chain in chains:
200196
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
201197
df = ttab.trace_to_dataframe(
202-
trace, chains=chain, flat_names=flat_names,hide_transformed_vars=False)
198+
trace, chains=chain, hide_transformed_vars=False)
203199
df.to_csv(filename, index=False)

pymc3/backends/tracetab.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
__all__ = ['trace_to_dataframe']
88

99

10-
def trace_to_dataframe(trace, chains=None, flat_names=None, hide_transformed_vars=True):
10+
def trace_to_dataframe(trace, chains=None, varnames=None, hide_transformed_vars=True):
1111
"""Convert trace to Pandas DataFrame.
1212
1313
Parameters
@@ -16,21 +16,29 @@ def trace_to_dataframe(trace, chains=None, flat_names=None, hide_transformed_var
1616
chains : int or list of ints
1717
Chains to include. If None, all chains are used. A single
1818
chain value can also be given.
19-
flat_names : dict or None
20-
A dictionary that maps each variable name in `trace` to a list
19+
varnames : list of variable names
20+
Variables to be included in the DataFrame, if None all variable are
21+
included.
22+
hide_transformed_vars: boolean
23+
If true transformed variables will not be included in the resulting
24+
DataFrame.
2125
"""
2226
var_shapes = trace._straces[0].var_shapes
23-
if flat_names is None:
24-
flat_names = {v: create_flat_names(v, shape)
25-
for v, shape in var_shapes.items()
26-
if not (hide_transformed_vars and v.endswith('_'))}
27+
28+
if varnames is None:
29+
varnames = var_shapes.keys()
2730

31+
flat_names = {v: create_flat_names(v, shape)
32+
for v, shape in var_shapes.items()
33+
if not (hide_transformed_vars and v.endswith('_'))}
34+
2835
var_dfs = []
29-
for varname, shape in var_shapes.items():
30-
if not hide_transformed_vars or not varname.endswith('_'):
31-
vals = trace.get_values(varname, combine=True, chains=chains)
32-
flat_vals = vals.reshape(vals.shape[0], -1)
33-
var_dfs.append(pd.DataFrame(flat_vals, columns=flat_names[varname]))
36+
for v, shape in var_shapes.items():
37+
if v in varnames:
38+
if not hide_transformed_vars or not v.endswith('_'):
39+
vals = trace.get_values(v, combine=True, chains=chains)
40+
flat_vals = vals.reshape(vals.shape[0], -1)
41+
var_dfs.append(pd.DataFrame(flat_vals, columns=flat_names[v]))
3442
return pd.concat(var_dfs, axis=1)
3543

3644

0 commit comments

Comments
 (0)