@@ -229,23 +229,34 @@ def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None):
229229 if validate_args is not None :
230230 self ._validate_args = validate_args
231231 if self ._validate_args :
232- for param , constraint in self .arg_constraints .items ():
233- if param not in self .__dict__ and isinstance (
234- getattr (type (self ), param ), lazy_property
235- ):
236- continue
237- if constraints .is_dependent (constraint ):
238- continue # skip constraints that cannot be checked
239- is_valid = constraint (getattr (self , param ))
240- if not_jax_tracer (is_valid ):
241- if not np .all (is_valid ):
242- raise ValueError (
243- "{} distribution got invalid {} parameter." .format (
244- self .__class__ .__name__ , param
245- )
246- )
232+ self .validate_args (strict = False )
247233 super (Distribution , self ).__init__ ()
248234
235+ def validate_args (self , strict : bool = True ) -> None :
236+ """
237+ Validate the arguments of the distribution.
238+
239+ :param strict: Require strict validation, raising an error if the function is
240+ called inside jitted code.
241+ """
242+ for param , constraint in self .arg_constraints .items ():
243+ if param not in self .__dict__ and isinstance (
244+ getattr (type (self ), param ), lazy_property
245+ ):
246+ continue
247+ if constraints .is_dependent (constraint ):
248+ continue # skip constraints that cannot be checked
249+ is_valid = constraint (getattr (self , param ))
250+ if not_jax_tracer (is_valid ):
251+ if not np .all (is_valid ):
252+ raise ValueError (
253+ "{} distribution got invalid {} parameter." .format (
254+ self .__class__ .__name__ , param
255+ )
256+ )
257+ elif strict :
258+ raise RuntimeError ("Cannot validate arguments inside jitted code." )
259+
249260 @property
250261 def batch_shape (self ):
251262 """
0 commit comments