@@ -77,7 +77,7 @@ def _set_sampler_vars(self, sampler_vars):
77
77
raise ValueError ("Backend does not support sampler stats." )
78
78
79
79
if self ._is_base_setup and self .sampler_vars != sampler_vars :
80
- raise ValueError ("Can't change sampler_vars" )
80
+ raise ValueError ("Can't change sampler_vars" )
81
81
82
82
if sampler_vars is None :
83
83
self .sampler_vars = None
@@ -311,8 +311,8 @@ def __getitem__(self, idx):
311
311
if var in self .varnames :
312
312
if var in self .stat_names :
313
313
warnings .warn ("Attribute access on a trace object is ambigous. "
314
- "Sampler statistic and model variable share a name. Use "
315
- "trace.get_values or trace.get_sampler_stats." )
314
+ "Sampler statistic and model variable share a name. Use "
315
+ "trace.get_values or trace.get_sampler_stats." )
316
316
return self .get_values (var , burn = burn , thin = thin )
317
317
if var in self .stat_names :
318
318
return self .get_sampler_stats (var , burn = burn , thin = thin )
@@ -331,8 +331,8 @@ def __getattr__(self, name):
331
331
if name in self .varnames :
332
332
if name in self .stat_names :
333
333
warnings .warn ("Attribute access on a trace object is ambigous. "
334
- "Sampler statistic and model variable share a name. Use "
335
- "trace.get_values or trace.get_sampler_stats." )
334
+ "Sampler statistic and model variable share a name. Use "
335
+ "trace.get_values or trace.get_sampler_stats." )
336
336
return self .get_values (name )
337
337
if name in self .stat_names :
338
338
return self .get_sampler_stats (name )
@@ -363,20 +363,28 @@ def stat_names(self):
363
363
names .update (vars .keys ())
364
364
return names
365
365
366
- def add_values (self , vals ):
367
- """add values to traces.
366
+ def add_values (self , vals , overwrite = False ):
367
+ """add variables to traces.
368
+
368
369
Parameters
369
370
----------
370
371
vals : dict (str: array-like)
371
- The keys should be the names of the new variables. The values are
372
- expected to be array-like object.
373
- For traces with more than one chain the length of each value
374
- should match the number of total samples already in the trace
375
- (chains * iterations), otherwise a warning is raised.
372
+ The keys should be the names of the new variables. The values are expected to be
373
+ array-like object. For traces with more than one chain the length of each value
374
+ should match the number of total samples already in the trace (chains * iterations),
375
+ otherwise a warning is raised.
376
+ overwrite : bool
377
+ If `False` (default) a ValueError is raised if the variable already exists.
378
+ Change to `True` to overwrite the values of variables
376
379
"""
377
380
for k , v in vals .items ():
381
+ new_var = 1
378
382
if k in self .varnames :
379
- raise ValueError ("Variable name {} already exists." .format (k ))
383
+ if overwrite :
384
+ self .varnames .remove (k )
385
+ new_var = 0
386
+ else :
387
+ raise ValueError ("Variable name {} already exists." .format (k ))
380
388
381
389
self .varnames .append (k )
382
390
@@ -392,9 +400,29 @@ def add_values(self, vals):
392
400
v = np .squeeze (v .reshape (len (chains ), len (self ), - 1 ))
393
401
394
402
for idx , chain in enumerate (chains .values ()):
403
+ if new_var :
404
+ dummy = tt .as_tensor_variable ([], k )
405
+ chain .vars .append (dummy )
395
406
chain .samples [k ] = v [idx ]
396
- dummy = tt .as_tensor_variable ([], k )
397
- chain .vars .append (dummy )
407
+
408
+ def remove_values (self , name ):
409
+ """remove variables from traces.
410
+
411
+ Parameters
412
+ ----------
413
+ name : str
414
+ Name of the variable to remove. Raises KeyError if the variable is not present
415
+ """
416
+ varnames = self .varnames
417
+ if name not in varnames :
418
+ raise KeyError ("Unknown variable {}" .format (name ))
419
+ self .varnames .remove (name )
420
+ chains = self ._straces
421
+ for chain in chains .values ():
422
+ for va in chain .vars :
423
+ if va .name == name :
424
+ chain .vars .remove (va )
425
+ del chain .samples [name ]
398
426
399
427
def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None ,
400
428
squeeze = True ):
0 commit comments