@@ -17,28 +17,116 @@ thread_local! {
1717 static TASK_LOCAL_STATE : RefCell <Option <RandomState >> = const { RefCell :: new( None ) } ;
1818}
1919
20- /// Random state factory
20+ /// Random state for reproducible random number generation.
21+ ///
22+ /// This struct holds the PRNG state and can be used with compiled functions
23+ /// to properly track random state across JIT compilation boundaries.
24+ ///
25+ /// # Compilation Support
26+ ///
27+ /// `RandomState` implements `Updatable`, making it compatible with
28+ /// `compile_with_state`. This is the Rust equivalent of Python's
29+ /// `@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)`.
30+ ///
31+ /// # Example
32+ ///
33+ /// ```rust,ignore
34+ /// use mlx_rs::random::RandomState;
35+ /// use mlx_rs::transforms::compile::compile_with_state;
36+ ///
37+ /// let mut state = RandomState::with_seed(42)?;
38+ /// let mut compiled = compile_with_state(
39+ /// |state: &mut RandomState, x: &Array| {
40+ /// let key = state.next_key()?;
41+ /// categorical(x, None, None, Some(&key))
42+ /// },
43+ /// None
44+ /// );
45+ /// let result = compiled(&mut state, &logits)?;
46+ /// ```
2147#[ derive( Debug , Clone ) ]
2248pub struct RandomState {
2349 state : Array ,
2450}
2551
2652impl RandomState {
27- fn new ( ) -> Result < Self > {
53+ /// Create a new random state with a time-based seed.
54+ pub fn new ( ) -> Result < Self > {
2855 let now = unsafe { mach_time:: mach_approximate_time ( ) } ;
2956 Ok ( Self { state : key ( now) ? } )
3057 }
3158
32- fn next ( & mut self ) -> Result < Array > {
59+ /// Create a new random state from a specific seed.
60+ ///
61+ /// Use this for reproducible random number generation.
62+ pub fn with_seed ( seed : u64 ) -> Result < Self > {
63+ Ok ( Self { state : key ( seed) ? } )
64+ }
65+
66+ /// Create a random state from an existing key array.
67+ ///
68+ /// The key must be a valid PRNG key (typically created via `random::key()`).
69+ pub fn from_key ( key : Array ) -> Self {
70+ Self { state : key }
71+ }
72+
73+ /// Get the next random key, advancing the state.
74+ ///
75+ /// This splits the current state into two keys: one becomes the new state,
76+ /// and the other is returned for use in random operations.
77+ pub fn next_key ( & mut self ) -> Result < Array > {
3378 let next = split ( & self . state , 2 ) ?;
3479 self . state = next. 0 ;
3580 Ok ( next. 1 )
3681 }
3782
38- fn seed ( & mut self , seed : u64 ) -> Result < ( ) > {
83+ /// Internal method for backward compatibility.
84+ fn next ( & mut self ) -> Result < Array > {
85+ self . next_key ( )
86+ }
87+
88+ /// Reseed the random state.
89+ pub fn seed ( & mut self , seed : u64 ) -> Result < ( ) > {
3990 self . state = key ( seed) ?;
4091 Ok ( ( ) )
4192 }
93+
94+ /// Get a reference to the underlying state array.
95+ ///
96+ /// This is useful for inspection or manual state management.
97+ pub fn as_array ( & self ) -> & Array {
98+ & self . state
99+ }
100+
101+ /// Get a mutable reference to the underlying state array.
102+ ///
103+ /// # Safety
104+ ///
105+ /// Modifying the state array directly may break the PRNG invariants.
106+ /// Prefer using `seed()` or `next_key()` instead.
107+ pub fn as_array_mut ( & mut self ) -> & mut Array {
108+ & mut self . state
109+ }
110+ }
111+
112+ impl Default for RandomState {
113+ fn default ( ) -> Self {
114+ Self :: new ( ) . expect ( "Failed to create default RandomState" )
115+ }
116+ }
117+
118+ impl crate :: utils:: Updatable for RandomState {
119+ fn updatable_states_len ( & self ) -> usize {
120+ 1
121+ }
122+
123+ fn updatable_states ( & self ) -> impl IntoIterator < Item = & Array > {
124+ std:: iter:: once ( & self . state )
125+ }
126+
127+ fn updatable_states_mut ( & mut self ) -> impl IntoIterator < Item = & mut Array > {
128+ std:: iter:: once ( & mut self . state )
129+ }
42130}
43131
44132fn global_state ( ) -> & ' static Mutex < RandomState > {
0 commit comments