Skip to content

Commit 6daf11b

Browse files
author
Junpeng Lao
authored
Merge pull request #2281 from aseyboldt/discrete-trafo
Disallow transformations for discrete variables
2 parents 4071c7d + df554dc commit 6daf11b

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pymc3/distributions/distribution.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def getattr_value(self, val):
8787

8888
def _repr_latex_(self, name=None, dist=None):
8989
return None
90-
90+
9191

9292
def TensorType(dtype, shape):
9393
return tt.TensorType(str(dtype), np.atleast_1d(shape) == 1)
@@ -123,6 +123,11 @@ def __init__(self, shape=(), dtype=None, defaults=('mode', ),
123123
dtype = 'int64'
124124
if dtype != 'int16' and dtype != 'int64':
125125
raise TypeError('Discrete classes expect dtype to be int16 or int64.')
126+
127+
if kwargs.get('transform', None) is not None:
128+
raise ValueError("Transformations for discrete distributions "
129+
"are not allowed.")
130+
126131
super(Discrete, self).__init__(
127132
shape, dtype, defaults=defaults, *args, **kwargs)
128133

pymc3/tests/test_distributions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,3 +860,13 @@ def test_repr_latex_():
860860
assert x2._repr_latex_()=='$Timeseries \\sim \\text{GaussianRandomWalk}(\\mathit{mu}=Continuous, \\mathit{sd}=1.0)$'
861861
assert x3._repr_latex_()=='$Multivariate \\sim \\text{MvStudentT}(\\mathit{nu}=5, \\mathit{mu}=Timeseries, \\mathit{Sigma}=array)$'
862862
assert x4._repr_latex_()=='$Mixture \\sim \\text{NormalMixture}(\\mathit{w}=array, \\mathit{mu}=Multivariate, \\mathit{sigma}=f(Discrete))$'
863+
864+
865+
def test_discrete_trafo():
866+
with pytest.raises(ValueError) as err:
867+
Binomial.dist(n=5, p=0.5, transform='log')
868+
err.match('Transformations for discrete distributions')
869+
with Model():
870+
with pytest.raises(ValueError) as err:
871+
Binomial('a', n=5, p=0.5, transform='log')
872+
err.match('Transformations for discrete distributions')

0 commit comments

Comments
 (0)