diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index ec33f6d..0e501d7 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -351,7 +351,7 @@ def to_batch_shape_index( A new index that is only applied on the batch shape. """ try: - new_index = [x[index] for x in np.indices(batch_shape)] + new_index = [x[index] for x in jnp.indices(batch_shape)] return tuple(new_index) except IndexError as e: raise IndexError(f'Batch shape `{batch_shape}` not compatible with index '