Skip to content

Commit 8e9313f

Browse files
Refactor validate_args method on distributions (fixes #1865). (#1866)
1 parent 94f4b99 commit 8e9313f

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

numpyro/distributions/distribution.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -229,23 +229,34 @@ def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None):
229229
if validate_args is not None:
230230
self._validate_args = validate_args
231231
if self._validate_args:
232-
for param, constraint in self.arg_constraints.items():
233-
if param not in self.__dict__ and isinstance(
234-
getattr(type(self), param), lazy_property
235-
):
236-
continue
237-
if constraints.is_dependent(constraint):
238-
continue # skip constraints that cannot be checked
239-
is_valid = constraint(getattr(self, param))
240-
if not_jax_tracer(is_valid):
241-
if not np.all(is_valid):
242-
raise ValueError(
243-
"{} distribution got invalid {} parameter.".format(
244-
self.__class__.__name__, param
245-
)
246-
)
232+
self.validate_args(strict=False)
247233
super(Distribution, self).__init__()
248234

235+
def validate_args(self, strict: bool = True) -> None:
236+
"""
237+
Validate the arguments of the distribution.
238+
239+
:param strict: Require strict validation, raising an error if the function is
240+
called inside jitted code.
241+
"""
242+
for param, constraint in self.arg_constraints.items():
243+
if param not in self.__dict__ and isinstance(
244+
getattr(type(self), param), lazy_property
245+
):
246+
continue
247+
if constraints.is_dependent(constraint):
248+
continue # skip constraints that cannot be checked
249+
is_valid = constraint(getattr(self, param))
250+
if not_jax_tracer(is_valid):
251+
if not np.all(is_valid):
252+
raise ValueError(
253+
"{} distribution got invalid {} parameter.".format(
254+
self.__class__.__name__, param
255+
)
256+
)
257+
elif strict:
258+
raise RuntimeError("Cannot validate arguments inside jitted code.")
259+
249260
@property
250261
def batch_shape(self):
251262
"""

test/test_distributions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3322,6 +3322,25 @@ def test_vmap_validate_args():
33223322
assert not v_dist._validate_args
33233323

33243324

3325+
def test_explicit_validate_args():
3326+
# Check validation passes for valid parameters.
3327+
d = dist.Normal(0, 1)
3328+
d.validate_args()
3329+
3330+
# Check validation fails for invalid parameters.
3331+
d = dist.Normal(0, -1)
3332+
with pytest.raises(ValueError, match="got invalid scale parameter"):
3333+
d.validate_args()
3334+
3335+
# Check validation is skipped for strict=False and raises an error for strict=True.
3336+
jitted = jax.jit(
3337+
lambda d, strict: d.validate_args(strict), static_argnames=["strict"]
3338+
)
3339+
jitted(d, False)
3340+
with pytest.raises(RuntimeError, match="Cannot validate arguments"):
3341+
jitted(d, True)
3342+
3343+
33253344
def test_multinomial_abstract_total_count():
33263345
probs = jnp.array([0.2, 0.5, 0.3])
33273346
key = random.PRNGKey(0)

0 commit comments

Comments
 (0)