Skip to content

Commit 2da2c6d

Browse files
authored
Allow arbitrary order of plate (#555)
1 parent 0c129d2 commit 2da2c6d

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

numpyro/primitives.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,7 @@ def process_message(self, msg):
265265
cond_indep_stack.append(frame)
266266
expected_shape = self._get_batch_shape(cond_indep_stack)
267267
dist_batch_shape = msg['fn'].batch_shape if msg['type'] == 'sample' else ()
268-
overlap_idx = len(expected_shape) - len(dist_batch_shape)
269-
if overlap_idx < 0:
270-
raise ValueError('Expected dimensions within plate = {}, which is less than the '
271-
'distribution\'s batch shape = {}.'.format(len(expected_shape), len(dist_batch_shape)))
268+
overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
272269
trailing_shape = expected_shape[overlap_idx:]
273270
# e.g. distribution with batch shape (1, 5) cannot be broadcast to (5, 5)
274271
broadcast_shape = lax.broadcast_shapes(trailing_shape, dist_batch_shape)

test/test_handlers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,16 @@ def model_nested_plates_2():
166166
assert xy.shape == (5, 1, 10)
167167

168168

169+
def model_nested_plates_3():
170+
outer = numpyro.plate('outer', 10, dim=-1)
171+
inner = numpyro.plate('inner', 5, dim=-2)
172+
numpyro.deterministic('z', 1.)
173+
174+
with inner, outer:
175+
xy = numpyro.sample('xy', dist.Normal(np.zeros((5, 10)), 1.))
176+
assert xy.shape == (5, 10)
177+
178+
169179
def model_dist_batch_shape():
170180
outer = numpyro.plate('outer', 10)
171181
inner = numpyro.plate('inner', 5, dim=-3)
@@ -204,6 +214,7 @@ def model_subsample_1():
204214
model_nested_plates_0,
205215
model_nested_plates_1,
206216
model_nested_plates_2,
217+
model_nested_plates_3,
207218
model_dist_batch_shape,
208219
model_subsample_1,
209220
])

0 commit comments

Comments
 (0)