 use std::{borrow::Cow, convert::Infallible};
 
 use general_sam::{BTreeTransTable, TravelEvent, Trie, TrieNodeAlike};
+use pyo3::{pyclass, pyfunction, pymethods, PyObject, PyResult, Python};
 
-use crate::{SortedTokenRange, TokenId};
+use crate::TokenId;
 
-#[derive(Clone, Debug)]
-#[cfg_attr(feature = "pyo3", ::pyo3::pyclass(get_all, set_all))]
+#[derive(Debug)]
+#[pyclass(get_all, set_all)]
 pub struct TokenSeqTrieNode {
     pub parent: usize,
     pub subtree_lower: usize,
     pub subtree_upper: usize,
+    pub depth: usize,
     pub token: TokenId,
-    pub pred_range: Option<SortedTokenRange>,
+    pub value: Option<PyObject>,
 }
 
-#[cfg(feature = "pyo3")]
-mod _pyo3 {
-    use pyo3::pymethods;
-
-    use super::TokenSeqTrieNode;
-
-    #[pymethods]
-    impl TokenSeqTrieNode {
-        fn __repr__(&self) -> String {
-            let Self {
-                parent,
-                subtree_lower,
-                subtree_upper,
-                token,
-                pred_range,
-            } = self;
-            let pred_range = pred_range
-                .as_ref()
-                .map(|r| r.repr_py())
-                .unwrap_or("None".to_owned());
-            format!(
-                "TokenSeqTrieNode(\
+#[pymethods]
+impl TokenSeqTrieNode {
+    fn __repr__<'py>(&self, py: Python<'py>) -> PyResult<String> {
+        let Self {
+            parent,
+            subtree_lower,
+            subtree_upper,
+            depth,
+            token,
+            value,
+        } = self;
+        let value = value
+            .as_ref()
+            .map(|v| v.call_method0(py, "__repr__")?.extract::<String>(py))
+            .transpose()?
+            .unwrap_or("None".to_owned());
+        Ok(format!(
+            "TokenSeqTrieNode(\
                 token={token}, \
-                pred_range={pred_range}, \
                 parent={parent}, \
                 subtree_lower={subtree_lower}, \
-                subtree_upper={subtree_upper})",
-            )
-        }
+                subtree_upper={subtree_upper}, \
+                depth={depth}, \
+                value={value})",
+        ))
     }
 }
 
 #[derive(Debug)]
 pub struct TokenSeqInput<'a> {
     pub tokens: Cow<'a, [TokenId]>,
-    pub pred_range: SortedTokenRange,
+    pub value: Option<PyObject>,
 }
 
 pub fn dfs_token_seq_trie(inputs: Vec<TokenSeqInput>) -> Vec<TokenSeqTrieNode> {
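
Note: the node layout above is enough to answer subtree queries without walking the tree. `dfs_token_seq_trie` returns nodes in DFS order, and the debug assertions further down check that `node.subtree_lower == i` and that a parent's `[subtree_lower, subtree_upper]` interval covers its children's. A minimal sketch of exploiting that (the helper is illustrative, not part of this crate):

fn is_in_subtree(nodes: &[TokenSeqTrieNode], ancestor: usize, descendant: usize) -> bool {
    // With nodes in DFS order, node `i` occupies DFS index `i`
    // (subtree_lower == i), so membership in `ancestor`'s subtree is an
    // interval-containment check on the DFS indices.
    let a = &nodes[ancestor];
    a.subtree_lower <= descendant && descendant <= a.subtree_upper
}
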
@@ -80,8 +78,9 @@ pub fn dfs_token_seq_trie(inputs: Vec<TokenSeqInput>) -> Vec<TokenSeqTrieNode> {
                     parent,
                     subtree_lower: dfs_order_id,
                     subtree_upper: dfs_order_id,
+                    depth: 0,
                     token,
-                    pred_range: None,
+                    value: None,
                 });
             }
             TravelEvent::Pop(node, _) => {
@@ -99,22 +98,71 @@ pub fn dfs_token_seq_trie(inputs: Vec<TokenSeqInput>) -> Vec<TokenSeqTrieNode> {
         Err(e) => match e {},
     }
 
+    for i in 0..dfs_order.len() {
+        let parent = dfs_order[i].parent;
+        if parent == i {
+            continue;
+        }
+        dfs_order[i].depth = dfs_order[parent].depth + 1;
+    }
+
     for (input, node_id) in inputs.into_iter().zip(seq_last_trie_node_ids) {
         if let Some(id) = rank[node_id] {
-            dfs_order[id].pred_range = Some(input.pred_range);
+            dfs_order[id].value = input.value;
         }
     }
 
     #[cfg(debug_assertions)]
     for (i, node) in dfs_order.iter().enumerate() {
         debug_assert!(node.subtree_lower <= node.subtree_upper);
+        debug_assert!(node.subtree_lower == i);
+        debug_assert!(node.parent <= i);
         if node.parent < node.subtree_lower {
             debug_assert!(node.parent != i);
             let parent = &dfs_order[node.parent];
             debug_assert!(parent.subtree_lower < node.subtree_lower);
             debug_assert!(parent.subtree_upper >= node.subtree_lower);
+        } else {
+            debug_assert!(node.parent == i);
         }
     }
 
     dfs_order
 }
+
+#[pyfunction(name = "dfs_token_seq_trie")]
+pub fn dfs_token_seq_trie_py<'py>(
+    py: Python<'py>,
+    inputs: Vec<(Vec<TokenId>, Option<PyObject>)>,
+) -> (Vec<TokenSeqTrieNode>, usize) {
+    debug_assert!(inputs
+        .iter()
+        .all(|(_, o)| o.as_ref().is_none_or(|v| !v.is_none(py))));
+
+    py.allow_threads(|| {
+        let inputs = inputs
+            .into_iter()
+            .map(|(s, v)| TokenSeqInput {
+                tokens: Cow::Owned(s),
+                value: v,
+            })
+            .collect();
+
+        let nodes = dfs_token_seq_trie(inputs);
+
+        let parent_chain_len = {
+            let mut res = 0;
+            while res < nodes.len() {
+                let node = &nodes[res];
+                if node.parent == res.saturating_sub(1) && node.value.is_none() {
+                    res += 1;
+                    continue;
+                }
+                break;
+            }
+            res
+        };
+
+        (nodes, parent_chain_len)
+    })
+}
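
Note: the returned `parent_chain_len` is purely mechanical: it counts the leading run of DFS nodes whose parent is the immediately preceding node and which carry no `value`. A minimal sketch of driving the Rust-side builder directly (the token ids are illustrative and it assumes `TokenId` is a plain integer alias; this is not a test from this commit):

use std::borrow::Cow;

fn demo() {
    // Two sequences sharing the prefix [1, 2]; neither carries a Python value.
    let inputs = vec![
        TokenSeqInput { tokens: Cow::Owned(vec![1, 2, 3]), value: None },
        TokenSeqInput { tokens: Cow::Owned(vec![1, 2, 4]), value: None },
    ];
    let nodes = dfs_token_seq_trie(inputs);
    // Nodes come back in DFS order, so every parent index precedes its node
    // and the single forward pass in dfs_token_seq_trie can fill `depth`.
    for (i, node) in nodes.iter().enumerate() {
        assert!(node.parent <= i);
        assert!(node.subtree_lower <= node.subtree_upper);
    }
}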