7
7
__all__ = ['trace_to_dataframe' ]
8
8
9
9
10
- def trace_to_dataframe (trace , chains = None , flat_names = None , hide_transformed_vars = True ):
10
+ def trace_to_dataframe (trace , chains = None , varnames = None , hide_transformed_vars = True ):
11
11
"""Convert trace to Pandas DataFrame.
12
12
13
13
Parameters
@@ -16,21 +16,29 @@ def trace_to_dataframe(trace, chains=None, flat_names=None, hide_transformed_var
16
16
chains : int or list of ints
17
17
Chains to include. If None, all chains are used. A single
18
18
chain value can also be given.
19
- flat_names : dict or None
20
- A dictionary that maps each variable name in `trace` to a list
19
+ varnames : list of variable names
20
+ Variables to be included in the DataFrame, if None all variable are
21
+ included.
22
+ hide_transformed_vars: boolean
23
+ If true transformed variables will not be included in the resulting
24
+ DataFrame.
21
25
"""
22
26
var_shapes = trace ._straces [0 ].var_shapes
23
- if flat_names is None :
24
- flat_names = {v : create_flat_names (v , shape )
25
- for v , shape in var_shapes .items ()
26
- if not (hide_transformed_vars and v .endswith ('_' ))}
27
+
28
+ if varnames is None :
29
+ varnames = var_shapes .keys ()
27
30
31
+ flat_names = {v : create_flat_names (v , shape )
32
+ for v , shape in var_shapes .items ()
33
+ if not (hide_transformed_vars and v .endswith ('_' ))}
34
+
28
35
var_dfs = []
29
- for varname , shape in var_shapes .items ():
30
- if not hide_transformed_vars or not varname .endswith ('_' ):
31
- vals = trace .get_values (varname , combine = True , chains = chains )
32
- flat_vals = vals .reshape (vals .shape [0 ], - 1 )
33
- var_dfs .append (pd .DataFrame (flat_vals , columns = flat_names [varname ]))
36
+ for v , shape in var_shapes .items ():
37
+ if v in varnames :
38
+ if not hide_transformed_vars or not v .endswith ('_' ):
39
+ vals = trace .get_values (v , combine = True , chains = chains )
40
+ flat_vals = vals .reshape (vals .shape [0 ], - 1 )
41
+ var_dfs .append (pd .DataFrame (flat_vals , columns = flat_names [v ]))
34
42
return pd .concat (var_dfs , axis = 1 )
35
43
36
44
0 commit comments