Skip to content

Commit d8557d6

Browse files
feat: public RandomState API + Updatable impl for compiled sampling
- Make RandomState methods public (new, with_seed, next_key, seed) - Implement Updatable trait for compile_with_state compatibility - Add Default impl for RandomState - Add helper methods (from_key, as_array, as_array_mut)
1 parent af21d79 commit d8557d6

File tree

1 file changed

+92
-4
lines changed

1 file changed

+92
-4
lines changed

mlx-rs/src/random.rs

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,116 @@ thread_local! {
1717
static TASK_LOCAL_STATE: RefCell<Option<RandomState>> = const { RefCell::new(None) };
1818
}
1919

20-
/// Random state factory
20+
/// Random state for reproducible random number generation.
21+
///
22+
/// This struct holds the PRNG state and can be used with compiled functions
23+
/// to properly track random state across JIT compilation boundaries.
24+
///
25+
/// # Compilation Support
26+
///
27+
/// `RandomState` implements `Updatable`, making it compatible with
28+
/// `compile_with_state`. This is the Rust equivalent of Python's
29+
/// `@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)`.
30+
///
31+
/// # Example
32+
///
33+
/// ```rust,ignore
34+
/// use mlx_rs::random::RandomState;
35+
/// use mlx_rs::transforms::compile::compile_with_state;
36+
///
37+
/// let mut state = RandomState::with_seed(42)?;
38+
/// let mut compiled = compile_with_state(
39+
/// |state: &mut RandomState, x: &Array| {
40+
/// let key = state.next_key()?;
41+
/// categorical(x, None, None, Some(&key))
42+
/// },
43+
/// None
44+
/// );
45+
/// let result = compiled(&mut state, &logits)?;
46+
/// ```
2147
#[derive(Debug, Clone)]
2248
pub struct RandomState {
2349
state: Array,
2450
}
2551

2652
impl RandomState {
27-
fn new() -> Result<Self> {
53+
/// Create a new random state with a time-based seed.
54+
pub fn new() -> Result<Self> {
2855
let now = unsafe { mach_time::mach_approximate_time() };
2956
Ok(Self { state: key(now)? })
3057
}
3158

32-
fn next(&mut self) -> Result<Array> {
59+
/// Create a new random state from a specific seed.
60+
///
61+
/// Use this for reproducible random number generation.
62+
pub fn with_seed(seed: u64) -> Result<Self> {
63+
Ok(Self { state: key(seed)? })
64+
}
65+
66+
/// Create a random state from an existing key array.
67+
///
68+
/// The key must be a valid PRNG key (typically created via `random::key()`).
69+
pub fn from_key(key: Array) -> Self {
70+
Self { state: key }
71+
}
72+
73+
/// Get the next random key, advancing the state.
74+
///
75+
/// This splits the current state into two keys: one becomes the new state,
76+
/// and the other is returned for use in random operations.
77+
pub fn next_key(&mut self) -> Result<Array> {
3378
let next = split(&self.state, 2)?;
3479
self.state = next.0;
3580
Ok(next.1)
3681
}
3782

38-
fn seed(&mut self, seed: u64) -> Result<()> {
83+
/// Internal method for backward compatibility.
84+
fn next(&mut self) -> Result<Array> {
85+
self.next_key()
86+
}
87+
88+
/// Reseed the random state.
89+
pub fn seed(&mut self, seed: u64) -> Result<()> {
3990
self.state = key(seed)?;
4091
Ok(())
4192
}
93+
94+
/// Get a reference to the underlying state array.
95+
///
96+
/// This is useful for inspection or manual state management.
97+
pub fn as_array(&self) -> &Array {
98+
&self.state
99+
}
100+
101+
/// Get a mutable reference to the underlying state array.
102+
///
103+
/// # Safety
104+
///
105+
/// Modifying the state array directly may break the PRNG invariants.
106+
/// Prefer using `seed()` or `next_key()` instead.
107+
pub fn as_array_mut(&mut self) -> &mut Array {
108+
&mut self.state
109+
}
110+
}
111+
112+
impl Default for RandomState {
113+
fn default() -> Self {
114+
Self::new().expect("Failed to create default RandomState")
115+
}
116+
}
117+
118+
impl crate::utils::Updatable for RandomState {
119+
fn updatable_states_len(&self) -> usize {
120+
1
121+
}
122+
123+
fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
124+
std::iter::once(&self.state)
125+
}
126+
127+
fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
128+
std::iter::once(&mut self.state)
129+
}
42130
}
43131

44132
fn global_state() -> &'static Mutex<RandomState> {

0 commit comments

Comments
 (0)