Skip to content

Commit d7d801f

Browse files
committed
Merge branch 'master' into release_3.4.2
2 parents a8d94d7 + b752b07 commit d7d801f

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

pymc3/backends/base.py

Lines changed: 43 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _set_sampler_vars(self, sampler_vars):
7777
raise ValueError("Backend does not support sampler stats.")
7878

7979
if self._is_base_setup and self.sampler_vars != sampler_vars:
80-
raise ValueError("Can't change sampler_vars")
80+
raise ValueError("Can't change sampler_vars")
8181

8282
if sampler_vars is None:
8383
self.sampler_vars = None
@@ -311,8 +311,8 @@ def __getitem__(self, idx):
311311
if var in self.varnames:
312312
if var in self.stat_names:
313313
warnings.warn("Attribute access on a trace object is ambigous. "
314-
"Sampler statistic and model variable share a name. Use "
315-
"trace.get_values or trace.get_sampler_stats.")
314+
"Sampler statistic and model variable share a name. Use "
315+
"trace.get_values or trace.get_sampler_stats.")
316316
return self.get_values(var, burn=burn, thin=thin)
317317
if var in self.stat_names:
318318
return self.get_sampler_stats(var, burn=burn, thin=thin)
@@ -331,8 +331,8 @@ def __getattr__(self, name):
331331
if name in self.varnames:
332332
if name in self.stat_names:
333333
warnings.warn("Attribute access on a trace object is ambigous. "
334-
"Sampler statistic and model variable share a name. Use "
335-
"trace.get_values or trace.get_sampler_stats.")
334+
"Sampler statistic and model variable share a name. Use "
335+
"trace.get_values or trace.get_sampler_stats.")
336336
return self.get_values(name)
337337
if name in self.stat_names:
338338
return self.get_sampler_stats(name)
@@ -363,20 +363,28 @@ def stat_names(self):
363363
names.update(vars.keys())
364364
return names
365365

366-
def add_values(self, vals):
367-
"""add values to traces.
366+
def add_values(self, vals, overwrite=False):
367+
"""add variables to traces.
368+
368369
Parameters
369370
----------
370371
vals : dict (str: array-like)
371-
The keys should be the names of the new variables. The values are
372-
expected to be array-like object.
373-
For traces with more than one chain the length of each value
374-
should match the number of total samples already in the trace
375-
(chains * iterations), otherwise a warning is raised.
372+
The keys should be the names of the new variables. The values are expected to be
373+
array-like object. For traces with more than one chain the length of each value
374+
should match the number of total samples already in the trace (chains * iterations),
375+
otherwise a warning is raised.
376+
overwrite : bool
377+
If `False` (default) a ValueError is raised if the variable already exists.
378+
Change to `True` to overwrite the values of variables
376379
"""
377380
for k, v in vals.items():
381+
new_var = 1
378382
if k in self.varnames:
379-
raise ValueError("Variable name {} already exists.".format(k))
383+
if overwrite:
384+
self.varnames.remove(k)
385+
new_var = 0
386+
else:
387+
raise ValueError("Variable name {} already exists.".format(k))
380388

381389
self.varnames.append(k)
382390

@@ -392,9 +400,29 @@ def add_values(self, vals):
392400
v = np.squeeze(v.reshape(len(chains), len(self), -1))
393401

394402
for idx, chain in enumerate(chains.values()):
403+
if new_var:
404+
dummy = tt.as_tensor_variable([], k)
405+
chain.vars.append(dummy)
395406
chain.samples[k] = v[idx]
396-
dummy = tt.as_tensor_variable([], k)
397-
chain.vars.append(dummy)
407+
408+
def remove_values(self, name):
409+
"""remove variables from traces.
410+
411+
Parameters
412+
----------
413+
name : str
414+
Name of the variable to remove. Raises KeyError if the variable is not present
415+
"""
416+
varnames = self.varnames
417+
if name not in varnames:
418+
raise KeyError("Unknown variable {}".format(name))
419+
self.varnames.remove(name)
420+
chains = self._straces
421+
for chain in chains.values():
422+
for va in chain.vars:
423+
if va.name == name:
424+
chain.vars.remove(va)
425+
del chain.samples[name]
398426

399427
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
400428
squeeze=True):

pymc3/tests/test_ndarray_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_merge_traces_nonunique(self):
120120
base.merge_traces([mtrace0, mtrace1])
121121

122122

123-
class TestMultiTrace_add_values(bf.ModelBackendSampledTestCase):
123+
class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
124124
name = None
125125
backend = ndarray.NDArray
126126
shape = ()
@@ -134,6 +134,9 @@ def test_add_values(self):
134134
assert len(orig_varnames) == len(mtrace.varnames) - 1
135135
assert name in mtrace.varnames
136136
assert np.all(mtrace[orig_varnames[0]] == mtrace[name])
137+
mtrace.remove_values(name)
138+
assert len(orig_varnames) == len(mtrace.varnames)
139+
assert name not in mtrace.varnames
137140

138141

139142
class TestSqueezeCat(object):

0 commit comments

Comments
 (0)