
Commit 5f729a3

twiecki and claude committed
Add object mode fallback for Numba RandomVariables
When we find a RandomVariable that doesn't have a Numba implementation, we now fall back to object mode instead of failing with NotImplementedError. This provides a more graceful degradation path for random variables that don't yet have specialized Numba implementations.

- Added rv_fallback_impl function to create the object-mode implementation
- Modified numba_funcify_RandomVariable to catch NotImplementedError
- Added a test for the unsupported random variable fallback

🤖 Generated with Claude Code

Co-Authored-By: Claude <[email protected]>
1 parent 7f03125 commit 5f729a3
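The mechanism behind this fallback is Numba's objmode context manager, which drops from compiled code back into the Python interpreter for the enclosed block. A minimal self-contained sketch of that escape hatch (illustrative only, not code from this commit; the function name and the "float64[:]" type string are assumptions):

    import numba
    import numpy as np

    @numba.njit
    def draw_normal(seed):
        # Everything inside objmode runs in the interpreter; any variable
        # carried back into compiled code must have its type declared.
        with numba.objmode(out="float64[:]"):
            rng = np.random.default_rng(seed)
            out = rng.standard_normal(3)
        return out

    print(draw_normal(123))  # works, but each call pays the object-mode cost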

2 files changed: +140 −32 lines

pytensor/link/numba/dispatch/random.py

Lines changed: 91 additions & 32 deletions
@@ -386,50 +386,45 @@ def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
     )
 
 
-@numba_funcify.register
-def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
-    core_shape = node.inputs[0]
+def rv_fallback_impl(op: RandomVariableWithCoreShape, node):
+    """Create a fallback implementation for random variables using object mode."""
+    import warnings
 
     [rv_node] = op.fgraph.apply_nodes
     rv_op: RandomVariable = rv_node.op
+
+    warnings.warn(
+        f"Numba will use object mode to execute the random variable {rv_op.name}",
+        UserWarning,
+    )
+
     size = rv_op.size_param(rv_node)
-    dist_params = rv_op.dist_params(rv_node)
     size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
-    core_shape_len = get_vector_length(core_shape)
     inplace = rv_op.inplace
 
-    core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
-    nin = 1 + len(dist_params)  # rng + params
-    core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
-
-    batch_ndim = rv_op.batch_ndim(rv_node)
-
-    # numba doesn't support nested literals right now...
-    input_bc_patterns = encode_literals(
-        tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params)
-    )
-    output_bc_patterns = encode_literals(
-        (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
-    )
-    output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
-    inplace_pattern = encode_literals(())
-
     def random_wrapper(core_shape, rng, size, *dist_params):
         if not inplace:
             rng = copy(rng)
 
-        draws = _vectorized(
-            core_op_fn,
-            input_bc_patterns,
-            output_bc_patterns,
-            output_dtypes,
-            inplace_pattern,
-            (rng,),
-            dist_params,
-            (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
-            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
+        fixed_size = (
+            None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len)
         )
-        return rng, draws
+
+        with numba.objmode(res="UniTuple(types.npy_rng, types.pyobject)"):
+            # Convert tuple params back to arrays for perform method
+            np_dist_params = [np.asarray(p) for p in dist_params]
+
+            # Prepare output storage for perform method
+            outputs = [[None], [None]]
+
+            # Call the perform method directly
+            rv_op.perform(rv_node, [rng, fixed_size, *np_dist_params], outputs)
+
+            next_rng = outputs[0][0]
+            result = outputs[1][0]
+            res = (next_rng, result)
+
+        return res
 
     def random(core_shape, rng, size, *dist_params):
         raise NotImplementedError("Non-jitted random variable not implemented")
@@ -439,3 +434,67 @@ def ov_random(core_shape, rng, size, *dist_params):
         return random_wrapper
 
     return random
+
+
+@numba_funcify.register
+def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
+    core_shape = node.inputs[0]
+
+    [rv_node] = op.fgraph.apply_nodes
+    rv_op: RandomVariable = rv_node.op
+    size = rv_op.size_param(rv_node)
+    dist_params = rv_op.dist_params(rv_node)
+    size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
+    core_shape_len = get_vector_length(core_shape)
+    inplace = rv_op.inplace
+
+    try:
+        core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
+        nin = 1 + len(dist_params)  # rng + params
+        core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
+
+        batch_ndim = rv_op.batch_ndim(rv_node)
+
+        # numba doesn't support nested literals right now...
+        input_bc_patterns = encode_literals(
+            tuple(
+                input_var.type.broadcastable[:batch_ndim] for input_var in dist_params
+            )
+        )
+        output_bc_patterns = encode_literals(
+            (rv_node.outputs[1].type.broadcastable[:batch_ndim],)
+        )
+        output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
+        inplace_pattern = encode_literals(())
+
+        def random_wrapper(core_shape, rng, size, *dist_params):
+            if not inplace:
+                rng = copy(rng)
+
+            draws = _vectorized(
+                core_op_fn,
+                input_bc_patterns,
+                output_bc_patterns,
+                output_dtypes,
+                inplace_pattern,
+                (rng,),
+                dist_params,
+                (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
+                None
+                if size_len is None
+                else numba_ndarray.to_fixed_tuple(size, size_len),
+            )
+            return rng, draws
+
+        def random(core_shape, rng, size, *dist_params):
+            raise NotImplementedError("Non-jitted random variable not implemented")
+
+        @overload(random, jit_options=_jit_options)
+        def ov_random(core_shape, rng, size, *dist_params):
+            return random_wrapper
+
+        return random
+
+    except NotImplementedError:
+        # Fall back to object mode for random variables that don't have core implementation
+        return rv_fallback_impl(op, node)
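The control flow added above — try the specialized path first, fall back when numba_core_rv_funcify raises NotImplementedError — is an instance of the dispatch-with-fallback pattern. A toy sketch under assumed names (FastOp, ExoticOp, and the lambdas are hypothetical stand-ins, not pytensor APIs):

    from functools import singledispatch

    class FastOp: ...
    class ExoticOp: ...

    @singledispatch
    def core_impl(op):
        # No specialized kernel registered for this op type.
        raise NotImplementedError(type(op).__name__)

    @core_impl.register
    def _(op: FastOp):
        return lambda x: 2 * x  # stands in for a jitted kernel

    def object_mode_impl(op):
        return lambda x: 2 * x  # stands in for the slower interpreted path

    def funcify(op):
        try:
            return core_impl(op)  # fast path when a kernel exists
        except NotImplementedError:
            return object_mode_impl(op)  # graceful degradation

    assert funcify(FastOp())(3) == 6    # specialized implementation
    assert funcify(ExoticOp())(3) == 6  # same result via the fallback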

tests/link/numba/test_random.py

Lines changed: 49 additions & 0 deletions
@@ -705,3 +705,52 @@ def test_repeated_args():
     final_node = fn.maker.fgraph.outputs[0].owner
     assert isinstance(final_node.op, RandomVariableWithCoreShape)
     assert final_node.inputs[-2] is final_node.inputs[-1]
+
+
+def test_unsupported_rv_fallback():
+    """Test that unsupported random variables fallback to object mode."""
+    import warnings
+
+    # Create a mock random variable that doesn't have a numba implementation
+    class CustomRV(ptr.RandomVariable):
+        name = "custom"
+        signature = "(d)->(d)"  # We need a parameter for test to pass
+        dtype = "float64"
+
+        def _supp_shape_from_params(self, dist_params, param_shapes=None):
+            # Return the shape of the support
+            return [1]
+
+        def rng_fn(self, rng, value, size=None):
+            # Just return the value plus a random number
+            return value + rng.standard_normal()
+
+    custom_rv = CustomRV()
+
+    # Create a graph with the unsupported RV
+    rng = shared(np.random.default_rng(123))
+    value = np.array(1.0)
+    x = custom_rv(value, rng=rng)
+
+    # Capture warnings to check for the fallback warning
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+
+        # Compile with numba mode
+        fn = function([], x, mode=numba_mode)
+
+        # Execute to trigger the fallback
+        result = fn()
+
+    # Check that a warning was raised about object mode
+    assert any("object mode" in str(warning.message) for warning in w)
+
+    # Verify the result is as expected
+    assert isinstance(result, np.ndarray)
+
+    # Run again to make sure the compiled function works properly
+    result2 = fn()
+    assert isinstance(result2, np.ndarray)
+    assert not np.array_equal(
+        result, result2
+    )  # Results should differ with different RNG states
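Assuming a pytensor source checkout with Numba installed, the new test can be run in isolation via pytest's node-id selection:

    pytest tests/link/numba/test_random.py::test_unsupported_rv_fallback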
