11//! Provides tools and interfaces to integrate the crate's functionality with Python.
22
3+ use std:: collections:: VecDeque ;
34use std:: sync:: Arc ;
45
56use bincode:: { config, Decode , Encode } ;
@@ -27,16 +28,19 @@ macro_rules! type_name {
2728pub struct PyGuide {
2829 state : StateId ,
2930 index : PyIndex ,
31+ state_cache : VecDeque < StateId > ,
3032}
3133
3234#[ pymethods]
3335impl PyGuide {
3436 /// Creates a Guide object based on Index.
3537 #[ new]
36- fn __new__ ( index : PyIndex ) -> Self {
38+ #[ pyo3( signature = ( index, max_rollback=32 ) ) ]
39+ fn __new__ ( index : PyIndex , max_rollback : usize ) -> Self {
3740 PyGuide {
3841 state : index. get_initial_state ( ) ,
3942 index,
43+ state_cache : VecDeque :: with_capacity ( max_rollback) ,
4044 }
4145 }
4246
@@ -57,6 +61,11 @@ impl PyGuide {
5761 ) ) )
5862 }
5963
64+ /// Get the number of rollback steps available.
65+ fn get_allowed_rollback ( & self ) -> usize {
66+ self . state_cache . len ( )
67+ }
68+
6069 /// Guide moves to the next state provided by the token id and returns a list of allowed tokens, unless return_tokens is False.
6170 #[ pyo3( signature = ( token_id, return_tokens=None ) ) ]
6271 fn advance (
@@ -66,6 +75,11 @@ impl PyGuide {
6675 ) -> PyResult < Option < Vec < TokenId > > > {
6776 match self . index . get_next_state ( self . state , token_id) {
6877 Some ( new_state) => {
78+ // Free up space in state_cache if needed.
79+ if self . state_cache . len ( ) == self . state_cache . capacity ( ) {
80+ self . state_cache . pop_front ( ) ;
81+ }
82+ self . state_cache . push_back ( self . state ) ;
6983 self . state = new_state;
7084 if return_tokens. unwrap_or ( true ) {
7185 self . get_tokens ( ) . map ( Some )
@@ -80,6 +94,41 @@ impl PyGuide {
8094 }
8195 }
8296
97+ /// Rollback the Guide state `n` tokens (states).
98+ /// Fails if `n` is greater than stored prior states.
99+ fn rollback_state ( & mut self , n : usize ) -> PyResult < ( ) > {
100+ if n == 0 {
101+ return Ok ( ( ) ) ;
102+ }
103+ if n > self . get_allowed_rollback ( ) {
104+ return Err ( PyValueError :: new_err ( format ! (
105+ "Cannot roll back {n} step(s): only {available} states stored (max_rollback = {cap}). \
106+ You must advance through at least {n} state(s) before rolling back {n} step(s).",
107+ cap = self . state_cache. capacity( ) ,
108+ available = self . get_allowed_rollback( ) ,
109+ ) ) ) ;
110+ }
111+ let mut new_state: u32 = self . state ;
112+ for _ in 0 ..n {
113+ // unwrap is safe because length is checked above
114+ new_state = self . state_cache . pop_back ( ) . unwrap ( ) ;
115+ }
116+ self . state = new_state;
117+ Ok ( ( ) )
118+ }
119+
120+ // Returns a boolean indicating if the sequence leads to a valid state in the DFA
121+ fn accepts_tokens ( & self , sequence : Vec < u32 > ) -> bool {
122+ let mut state = self . state ;
123+ for t in sequence {
124+ match self . index . get_next_state ( state, t) {
125+ Some ( s) => state = s,
126+ None => return false ,
127+ }
128+ }
129+ true
130+ }
131+
83132 /// Checks if the automaton is in a final state.
84133 fn is_finished ( & self ) -> bool {
85134 self . index . is_final_state ( self . state )
0 commit comments