Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 144 additions & 4 deletions mlx-rs/src/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,125 @@ thread_local! {
static TASK_LOCAL_STATE: RefCell<Option<RandomState>> = const { RefCell::new(None) };
}

/// Random state factory
/// Random state for reproducible random number generation.
///
/// This struct holds the PRNG state and can be used with compiled functions
/// to properly track random state across JIT compilation boundaries.
///
/// # Compilation Support
///
/// `RandomState` implements `Updatable`, making it compatible with
/// `compile_with_state`. This is the Rust equivalent of Python's
/// `@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)`.
///
/// # Example
///
/// ```rust,no_run
/// use mlx_rs::random::RandomState;
/// use mlx_rs::transforms::compile::compile_with_state;
/// use mlx_rs::random::categorical;
/// use mlx_rs::Array;
///
/// let mut state = RandomState::with_seed(42).unwrap();
/// let logits = Array::zeros::<f32>(&[1, 10]).unwrap();
/// let mut compiled = compile_with_state(
/// |state: &mut RandomState, x: &Array| {
/// let key = state.next_key()?;
/// categorical(x, None, None, Some(&key))
/// },
/// None
/// );
/// let result = compiled(&mut state, &logits).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct RandomState {
state: Array,
}

impl RandomState {
fn new() -> Result<Self> {
/// Create a new random state with a time-based seed.
pub fn new() -> Result<Self> {
let now = unsafe { mach_time::mach_approximate_time() };
Ok(Self { state: key(now)? })
}

fn next(&mut self) -> Result<Array> {
/// Create a new random state from a specific seed.
///
/// Use this for reproducible random number generation.
pub fn with_seed(seed: u64) -> Result<Self> {
Ok(Self { state: key(seed)? })
}

/// Create a random state from an existing key array.
///
/// The key must be a valid PRNG key (typically created via `random::key()`).
pub fn from_key(key: Array) -> Self {
Self { state: key }
}

/// Get the next random key, advancing the state.
///
/// This splits the current state into two keys: one becomes the new state,
/// and the other is returned for use in random operations.
pub fn next_key(&mut self) -> Result<Array> {
let next = split(&self.state, 2)?;
self.state = next.0;
Ok(next.1)
}

fn seed(&mut self, seed: u64) -> Result<()> {
/// Internal method for backward compatibility.
fn next(&mut self) -> Result<Array> {
self.next_key()
}

/// Reseed the random state.
pub fn seed(&mut self, seed: u64) -> Result<()> {
self.state = key(seed)?;
Ok(())
}

/// Get a reference to the underlying state array.
///
/// This is useful for inspection or manual state management.
pub fn as_array(&self) -> &Array {
&self.state
}

/// Get a mutable reference to the underlying state array.
///
/// # Note
///
/// Modifying the state array directly may break the PRNG invariants.
/// Prefer using `seed()` or `next_key()` instead.
pub fn as_array_mut(&mut self) -> &mut Array {
&mut self.state
}
}

impl Default for RandomState {
/// Creates a new `RandomState` with a time-based seed.
///
/// # Panics
///
/// Panics if the underlying PRNG key creation fails, which should not
/// occur under normal conditions.
fn default() -> Self {
Self::new().expect("Failed to create default RandomState")
}
}

impl crate::utils::Updatable for RandomState {
fn updatable_states_len(&self) -> usize {
1
}

fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
std::iter::once(&self.state)
}

fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
std::iter::once(&mut self.state)
}
}

fn global_state() -> &'static Mutex<RandomState> {
Expand Down Expand Up @@ -717,6 +814,49 @@ mod tests {
assert_array_eq!(result, expected, 0.01);
}

#[test]
fn test_random_state_new() {
let state = RandomState::new().unwrap();
assert_eq!(state.as_array().shape(), &[2]);
}

#[test]
fn test_random_state_with_seed_deterministic() {
let s1 = RandomState::with_seed(42).unwrap();
let s2 = RandomState::with_seed(42).unwrap();
assert!(s1.as_array() == s2.as_array());
}

#[test]
fn test_random_state_next_key_advances() {
let mut state = RandomState::with_seed(0).unwrap();
let k1 = state.next_key().unwrap();
let k2 = state.next_key().unwrap();
assert!(k1 != k2);
}

#[test]
fn test_random_state_from_key_roundtrip() {
let original = RandomState::with_seed(99).unwrap();
let arr = original.as_array().clone();
let restored = RandomState::from_key(arr);
assert!(original.as_array() == restored.as_array());
}

#[test]
fn test_random_state_updatable() {
use crate::utils::Updatable;
let state = RandomState::with_seed(0).unwrap();
assert_eq!(state.updatable_states_len(), 1);
assert_eq!(state.updatable_states().into_iter().count(), 1);
}

#[test]
fn test_random_state_default() {
let state = RandomState::default();
assert_eq!(state.as_array().shape(), &[2]);
}

#[test]
fn test_random_seed_same() {
// Same random seed should produce the same results
Expand Down
Loading