Skip to content

Commit 06d4415

Browse files
author
Junpeng Lao
authored
Fix potential type conflict in observedRV (#3067)
* fix potential type conflict in observedRV For the discussion see https://discourse.pymc.io/t/shapes-of-shared-variables/1434/ * better type checking * fix test float32
1 parent 6fd230f commit 06d4415

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

pymc3/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,10 @@ def __init__(self, type=None, owner=None, index=None, name=None, data=None,
12931293
needed for upscaling logp
12941294
"""
12951295
from .distributions import TensorType
1296+
1297+
if hasattr(data, 'type') and isinstance(data.type, tt.TensorType):
1298+
type = data.type
1299+
12961300
if type is None:
12971301
data = pandas_to_array(data)
12981302
type = TensorType(distribution.dtype, data.shape)

pymc3/tests/test_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,16 @@ def test_observed_rv_fail(self):
130130
x = Normal('x')
131131
Normal('n', observed=x)
132132

133+
def test_observed_type(self):
134+
X_ = np.random.randn(100, 5)
135+
X = pm.floatX(theano.shared(X_))
136+
with pm.Model():
137+
x1 = pm.Normal('x1', observed=X_)
138+
x2 = pm.Normal('x2', observed=X)
139+
140+
assert x1.type == X.type
141+
assert x2.type == X.type
142+
133143

134144
class TestTheanoConfig(object):
135145
def test_set_testval_raise(self):

0 commit comments

Comments
 (0)