Skip to content

Commit 10548fe

Browse files
committed
iterating of support_sizes
1 parent e14eea7 commit 10548fe

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

numpyro/infer/hmc_gibbs.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -476,16 +476,8 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
476476
# Each support is padded with zeros to have the same length
477477
# ravel is used to maintain a consistant behaviour with `support_sizes`
478478

479-
max_length_support_enumerates = max(
480-
(
481-
site["fn"].enumerate_support(False).shape[0]
482-
for site in self._prototype_trace.values()
483-
if site["type"] == "sample"
484-
and site["fn"].has_enumerate_support
485-
and not site["is_observed"]
486-
)
487-
)
488-
479+
max_length_support_enumerates = max(size for size in self._support_sizes.values())
480+
489481
support_enumerates = {}
490482
for name, support_size in self._support_sizes.items():
491483
site = self._prototype_trace[name]

0 commit comments

Comments
 (0)