@@ -386,50 +386,45 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
     )
 
 
-@numba_funcify.register
-def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
-    core_shape = node.inputs[0]
+def rv_fallback_impl(op: RandomVariableWithCoreShape, node):
+    """Create a fallback implementation for random variables using object mode."""
+    import warnings
 
     [rv_node] = op.fgraph.apply_nodes
     rv_op: RandomVariable = rv_node.op
+
+    warnings.warn(
+        f"Numba will use object mode to execute the random variable {rv_op.name}",
+        UserWarning,
+    )
+
     size = rv_op.size_param(rv_node)
-    dist_params = rv_op.dist_params(rv_node)
     size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
-    core_shape_len = get_vector_length(core_shape)
     inplace = rv_op.inplace
 
-    core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
-    nin = 1 + len(dist_params)  # rng + params
-    core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
-
-    batch_ndim = rv_op.batch_ndim(rv_node)
-
-    # numba doesn't support nested literals right now...
-    input_bc_patterns = encode_literals(
-        tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params)
-    )
-    output_bc_patterns = encode_literals(
-        (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
-    )
-    output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
-    inplace_pattern = encode_literals(())
-
     def random_wrapper(core_shape, rng, size, *dist_params):
         if not inplace:
             rng = copy(rng)
 
-        draws = _vectorized(
-            core_op_fn,
-            input_bc_patterns,
-            output_bc_patterns,
-            output_dtypes,
-            inplace_pattern,
-            (rng,),
-            dist_params,
-            (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
-            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
+        fixed_size = (
+            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len)
         )
-        return rng, draws
+
+        with numba.objmode(res="UniTuple(types.npy_rng, types.pyobject)"):
+            # Convert tuple params back to arrays for perform method
+            np_dist_params = [np.asarray(p) for p in dist_params]
+
+            # Prepare output storage for perform method
+            outputs = [[None], [None]]
+
+            # Call the perform method directly
+            rv_op.perform(rv_node, [rng, fixed_size, *np_dist_params], outputs)
+
+            next_rng = outputs[0][0]
+            result = outputs[1][0]
+            res = (next_rng, result)
+
+        return res
 
     def random(core_shape, rng, size, *dist_params):
         raise NotImplementedError("Non-jitted random variable not implemented")
@@ -439,3 +434,67 @@ def ov_random(core_shape, rng, size, *dist_params):
         return random_wrapper
 
     return random
+
+
+@numba_funcify.register
+def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
+    core_shape = node.inputs[0]
+
+    [rv_node] = op.fgraph.apply_nodes
+    rv_op: RandomVariable = rv_node.op
+    size = rv_op.size_param(rv_node)
+    dist_params = rv_op.dist_params(rv_node)
+    size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
+    core_shape_len = get_vector_length(core_shape)
+    inplace = rv_op.inplace
+
+    try:
+        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+        nin = 1 + len(dist_params)  # rng + params
+        core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
+
+        batch_ndim = rv_op.batch_ndim(rv_node)
+
+        # numba doesn't support nested literals right now...
+        input_bc_patterns = encode_literals(
+            tuple(
+                input_var.type.broadcastable[:batch_ndim] for input_var in dist_params
+            )
+        )
+        output_bc_patterns = encode_literals(
+            (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
+        )
+        output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
+        inplace_pattern = encode_literals(())
+
+        def random_wrapper(core_shape, rng, size, *dist_params):
+            if not inplace:
+                rng = copy(rng)
+
+            draws = _vectorized(
+                core_op_fn,
+                input_bc_patterns,
+                output_bc_patterns,
+                output_dtypes,
+                inplace_pattern,
+                (rng,),
+                dist_params,
+                (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
+                None
+                if size_len is None
+                else numba_ndarray.to_fixed_tuple(size, size_len),
+            )
+            return rng, draws
+
+        def random(core_shape, rng, size, *dist_params):
+            raise NotImplementedError("Non-jitted random variable not implemented")
+
+        @overload(random, jit_options=_jit_options)
+        def ov_random(core_shape, rng, size, *dist_params):
+            return random_wrapper
+
+        return random
+
+    except NotImplementedError:
+        # Fall back to object mode for random variables that don't have a core implementation
+        return rv_fallback_impl(op, node)
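
The fallback path above relies on numba's objmode context manager, which lets a jitted function drop into the regular Python interpreter for a block and hand typed results back to nopython code. The following is a minimal, self-contained sketch of that pattern, not part of this diff; the function name and the "float64[:]" result type are illustrative assumptions, not pytensor code.

import numpy as np
import numba


@numba.njit
def draw_with_fallback(n):
    # Everything inside this block runs in object (interpreter) mode,
    # so arbitrary Python such as np.random.Generator calls is allowed.
    with numba.objmode(res="float64[:]"):
        rng = np.random.default_rng(123)
        res = rng.standard_normal(n)
    # `res` re-enters nopython code as a 1-D float64 array.
    return res


draw_with_fallback(5)

In the PR, the same mechanism is used to call rv_op.perform directly when numba_core_rv_funcify raises NotImplementedError, trading speed for coverage of random variables without a numba core implementation.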