Skip to content

Commit f4b9d99

Browse files
committed
fixed lint issues
1 parent 10548fe commit f4b9d99

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

numpyro/infer/hmc_gibbs.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import copy
66
from functools import partial
77

8-
import jax
98
import numpy as np
109

10+
import jax
1111
from jax import device_put, grad, jacfwd, random, value_and_grad
1212
from jax.flatten_util import ravel_pytree
1313
import jax.numpy as jnp
@@ -476,7 +476,9 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
476476
# Each support is padded with zeros to have the same length
477477
# ravel is used to maintain a consistant behaviour with `support_sizes`
478478

479-
max_length_support_enumerates = max(size for size in self._support_sizes.values())
479+
max_length_support_enumerates = max(
480+
size for size in self._support_sizes.values()
481+
)
480482

481483
support_enumerates = {}
482484
for name, support_size in self._support_sizes.items():
@@ -493,7 +495,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
493495
support_enumerates[name] = padded_enumerate_support
494496

495497
self._support_enumerates = jax.vmap(
496-
lambda x: ravel_pytree(x)[0] , in_axes=0, out_axes=1
498+
lambda x: ravel_pytree(x)[0], in_axes=0, out_axes=1
497499
)(support_enumerates)
498500

499501
self._gibbs_sites = [

numpyro/infer/mixed_hmc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from jax import grad, jacfwd, lax, random
88
from jax.flatten_util import ravel_pytree
9-
import jax
109
import jax.numpy as jnp
1110

1211
from numpyro.infer.hmc import momentum_generator

0 commit comments

Comments
 (0)