Skip to content

Commit 9b059fa

Browse files
committed
feat: allow any object as extra values
1 parent e53862d commit 9b059fa

File tree

7 files changed

+134
-83
lines changed

7 files changed

+134
-83
lines changed

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ homepage = "https://github.com/ModelTC/mtc-token-healing"
1111
documentation = "https://docs.rs/mtc-token-healing"
1212
authors = ["Chielo Newctle <[email protected]>"]
1313

14+
[workspace.dependencies]
15+
general-sam = { version = "1.0.1", features = ["trie"] }
16+
1417
[package]
1518
name = "mtc-token-healing"
1619
version.workspace = true
@@ -26,7 +29,7 @@ exclude = ["release-plz.toml", ".github", "python"]
2629

2730
[dependencies]
2831
derive_more = { version = "2.0.1", features = ["deref", "as_ref"] }
29-
general-sam = { version = "1.0.1", features = ["trie"] }
32+
general-sam = { workspace = true }
3033
pyo3 = { version = "0.25.0", optional = true }
3134
thiserror = "2.0.12"
3235
tinyvec = { version = "1.9.0", features = ["alloc"] }

python/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ crate-type = ["cdylib"]
1616

1717
[dependencies]
1818
mtc-token-healing = { version = "0.2.2-dev", path = "..", features = ["pyo3"] }
19+
general-sam = { workspace = true }
1920
pyo3 = { version = "0.25.0", features = ["extension-module", "generate-import-lib", "abi3-py310"] }

python/mtc_token_healing/mtc_token_healing.pyi

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence, Tuple, overload
1+
from typing import Generic, Optional, Sequence, Tuple, TypeVar, overload
22

33
TokenId = int
44
SortedTokenId = int
@@ -35,14 +35,16 @@ class VocabPrefixAutomaton:
3535
self, token_ids: Sequence[TokenId]
3636
) -> Sequence[SortedTokenId]: ...
3737

38-
class TokenSeqTrieNode:
38+
Value = TypeVar("Value")
39+
40+
class TokenSeqTrieNode(Generic[Value]):
3941
token: int
40-
pred_range: Optional[SortedTokenRange]
4142
parent: int
4243
subtree_lower: int
4344
subtree_upper: int
45+
depth: int
46+
value: Optional[Value]
4447

4548
def dfs_token_seq_trie(
46-
token_ids_seq: Sequence[Sequence[int]],
47-
pred_rank_ranges: Sequence[SortedTokenRange],
48-
) -> Tuple[Sequence[TokenSeqTrieNode], int]: ...
49+
token_ids_seq_and_values: Sequence[Tuple[Sequence[TokenId], Value]],
50+
) -> Tuple[Sequence[TokenSeqTrieNode[Value]], int]: ...

python/src/lib.rs

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,9 @@
1-
use std::borrow::Cow;
1+
mod prefix_dfs;
22

3-
use ::mtc_token_healing::{
4-
dfs_token_seq_trie, SortedTokenRange, TokenId, TokenSeqInput, TokenSeqTrieNode,
5-
VocabPrefixAutomaton,
6-
};
3+
use ::mtc_token_healing::{SortedTokenRange, TokenId, VocabPrefixAutomaton};
74
use pyo3::prelude::*;
85

9-
#[pyfunction(name = "dfs_token_seq_trie")]
10-
fn dfs_token_seq_trie_py(
11-
token_ids: Vec<Vec<TokenId>>,
12-
pred_rank_ranges: Vec<SortedTokenRange>,
13-
) -> (Vec<TokenSeqTrieNode>, usize) {
14-
let inputs = token_ids
15-
.into_iter()
16-
.zip(pred_rank_ranges)
17-
.map(|(s, r)| TokenSeqInput {
18-
tokens: Cow::Owned(s),
19-
pred_range: r,
20-
})
21-
.collect();
22-
let nodes = dfs_token_seq_trie(inputs);
23-
let parent_chain_len = {
24-
let mut res = 0;
25-
while res < nodes.len() {
26-
let node = &nodes[res];
27-
if node.parent == res.saturating_sub(1) && node.pred_range.is_none() {
28-
res += 1;
29-
continue;
30-
}
31-
break;
32-
}
33-
res
34-
};
35-
(nodes, parent_chain_len)
36-
}
6+
use crate::prefix_dfs::{dfs_token_seq_trie_py, TokenSeqTrieNode};
377

388
#[pymodule]
399
fn mtc_token_healing(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,53 @@
11
use std::{borrow::Cow, convert::Infallible};
22

33
use general_sam::{BTreeTransTable, TravelEvent, Trie, TrieNodeAlike};
4+
use pyo3::{pyclass, pyfunction, pymethods, PyObject, PyResult, Python};
45

5-
use crate::{SortedTokenRange, TokenId};
6+
use crate::TokenId;
67

7-
#[derive(Clone, Debug)]
8-
#[cfg_attr(feature = "pyo3", ::pyo3::pyclass(get_all, set_all))]
8+
#[derive(Debug)]
9+
#[pyclass(get_all, set_all)]
910
pub struct TokenSeqTrieNode {
1011
pub parent: usize,
1112
pub subtree_lower: usize,
1213
pub subtree_upper: usize,
14+
pub depth: usize,
1315
pub token: TokenId,
14-
pub pred_range: Option<SortedTokenRange>,
16+
pub value: Option<PyObject>,
1517
}
1618

17-
#[cfg(feature = "pyo3")]
18-
mod _pyo3 {
19-
use pyo3::pymethods;
20-
21-
use super::TokenSeqTrieNode;
22-
23-
#[pymethods]
24-
impl TokenSeqTrieNode {
25-
fn __repr__(&self) -> String {
26-
let Self {
27-
parent,
28-
subtree_lower,
29-
subtree_upper,
30-
token,
31-
pred_range,
32-
} = self;
33-
let pred_range = pred_range
34-
.as_ref()
35-
.map(|r| r.repr_py())
36-
.unwrap_or("None".to_owned());
37-
format!(
38-
"TokenSeqTrieNode(\
19+
#[pymethods]
20+
impl TokenSeqTrieNode {
21+
fn __repr__<'py>(&self, py: Python<'py>) -> PyResult<String> {
22+
let Self {
23+
parent,
24+
subtree_lower,
25+
subtree_upper,
26+
depth,
27+
token,
28+
value,
29+
} = self;
30+
let value = value
31+
.as_ref()
32+
.map(|v| v.call_method0(py, "__repr__")?.extract::<String>(py))
33+
.transpose()?
34+
.unwrap_or("None".to_owned());
35+
Ok(format!(
36+
"TokenSeqTrieNode(\
3937
token={token}, \
40-
pred_range={pred_range}, \
4138
parent={parent}, \
4239
subtree_lower={subtree_lower}, \
43-
subtree_upper={subtree_upper})",
44-
)
45-
}
40+
subtree_upper={subtree_upper}, \
41+
depth={depth}, \
42+
value={value})",
43+
))
4644
}
4745
}
4846

4947
#[derive(Debug)]
5048
pub struct TokenSeqInput<'a> {
5149
pub tokens: Cow<'a, [TokenId]>,
52-
pub pred_range: SortedTokenRange,
50+
pub value: Option<PyObject>,
5351
}
5452

5553
pub fn dfs_token_seq_trie(inputs: Vec<TokenSeqInput>) -> Vec<TokenSeqTrieNode> {
@@ -80,8 +78,9 @@ pub fn dfs_token_seq_trie(inputs: Vec<TokenSeqInput>) -> Vec<TokenSeqTrieNode> {
8078
parent,
8179
subtree_lower: dfs_order_id,
8280
subtree_upper: dfs_order_id,
81+
depth: 0,
8382
token,
84-
pred_range: None,
83+
value: None,
8584
});
8685
}
8786
TravelEvent::Pop(node, _) => {
@@ -99,22 +98,71 @@ pub fn dfs_token_seq_trie(inputs: Vec<TokenSeqInput>) -> Vec<TokenSeqTrieNode> {
9998
Err(e) => match e {},
10099
}
101100

101+
for i in 0..dfs_order.len() {
102+
let parent = dfs_order[i].parent;
103+
if parent == i {
104+
continue;
105+
}
106+
dfs_order[i].depth = dfs_order[parent].depth + 1;
107+
}
108+
102109
for (input, node_id) in inputs.into_iter().zip(seq_last_trie_node_ids) {
103110
if let Some(id) = rank[node_id] {
104-
dfs_order[id].pred_range = Some(input.pred_range);
111+
dfs_order[id].value = input.value;
105112
}
106113
}
107114

108115
#[cfg(debug_assertions)]
109116
for (i, node) in dfs_order.iter().enumerate() {
110117
debug_assert!(node.subtree_lower <= node.subtree_upper);
118+
debug_assert!(node.subtree_lower == i);
119+
debug_assert!(node.parent <= i);
111120
if node.parent < node.subtree_lower {
112121
debug_assert!(node.parent != i);
113122
let parent = &dfs_order[node.parent];
114123
debug_assert!(parent.subtree_lower < node.subtree_lower);
115124
debug_assert!(parent.subtree_upper >= node.subtree_lower);
125+
} else {
126+
debug_assert!(node.parent == i);
116127
}
117128
}
118129

119130
dfs_order
120131
}
132+
133+
#[pyfunction(name = "dfs_token_seq_trie")]
134+
pub fn dfs_token_seq_trie_py<'py>(
135+
py: Python<'py>,
136+
inputs: Vec<(Vec<TokenId>, Option<PyObject>)>,
137+
) -> (Vec<TokenSeqTrieNode>, usize) {
138+
debug_assert!(inputs
139+
.iter()
140+
.all(|(_, o)| o.as_ref().is_none_or(|v| !v.is_none(py))));
141+
142+
py.allow_threads(|| {
143+
let inputs = inputs
144+
.into_iter()
145+
.map(|(s, v)| TokenSeqInput {
146+
tokens: Cow::Owned(s),
147+
value: v,
148+
})
149+
.collect();
150+
151+
let nodes = dfs_token_seq_trie(inputs);
152+
153+
let parent_chain_len = {
154+
let mut res = 0;
155+
while res < nodes.len() {
156+
let node = &nodes[res];
157+
if node.parent == res.saturating_sub(1) && node.value.is_none() {
158+
res += 1;
159+
continue;
160+
}
161+
break;
162+
}
163+
res
164+
};
165+
166+
(nodes, parent_chain_len)
167+
})
168+
}

python/tests/test_dfs_token_seq_trie.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,58 @@ def test_dfs_token_seq_trie():
1616
SortedTokenRange(1, 10),
1717
SortedTokenRange(5, 9),
1818
]
19+
return run_test_dfs_token_seq_trie(tokens_seq, pred_ranges)
1920

20-
nodes = dfs_token_seq_trie(tokens_seq, pred_ranges)
2121

22+
def test_dfs_trie_value_on_prefix_chain():
23+
tokens_seq = [
24+
[3, 9],
25+
[3, 9, 1, 10, 9, 6, 7],
26+
[3, 9, 1, 10, 9, 5],
27+
[3, 9, 1, 10, 2],
28+
[3, 9, 1, 10],
29+
[3, 9, 1, 11],
30+
]
31+
pred_ranges = [
32+
SortedTokenRange(0, 12),
33+
SortedTokenRange(4, 6),
34+
SortedTokenRange(4, 7),
35+
SortedTokenRange(3, 9),
36+
SortedTokenRange(1, 10),
37+
SortedTokenRange(5, 9),
38+
]
39+
return run_test_dfs_token_seq_trie(tokens_seq, pred_ranges)
40+
41+
42+
def run_test_dfs_token_seq_trie(tokens_seq, pred_ranges):
43+
nodes, pre_len = dfs_token_seq_trie(list(zip(tokens_seq, pred_ranges)))
44+
45+
print(f"{pre_len=}")
2246
print([node.token for node in nodes])
47+
print([node.depth for node in nodes])
48+
print([node.subtree_upper for node in nodes])
2349

24-
for i, node in enumerate(nodes):
25-
if node.pred_range is None:
50+
for q, node in enumerate(nodes):
51+
if node.value is None:
2652
continue
2753
seq = []
28-
for j in range(i):
54+
for j in range(q):
2955
if nodes[j].subtree_upper >= node.subtree_upper:
3056
seq.append(nodes[j].token)
3157
seq.append(node.token)
32-
print(seq)
58+
print(seq, node.value)
3359
assert seq in tokens_seq
60+
assert pred_ranges[tokens_seq.index(seq)] == node.value
61+
assert node.depth + 1 == len(seq)
3462

35-
for i in range(len(nodes)):
63+
for q in range(len(nodes)):
3664
masks = [
37-
j <= i and nodes[j].subtree_upper >= nodes[i].subtree_upper
38-
for j in range(len(nodes))
65+
k <= q and nodes[k].subtree_upper >= nodes[q].subtree_upper
66+
for k in range(len(nodes))
3967
]
4068
print("".join(map(str, map(int, masks))))
4169

4270

4371
if __name__ == "__main__":
4472
test_dfs_token_seq_trie()
73+
test_dfs_trie_value_on_prefix_chain()

src/lib.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
//! is the same as walking to the state on the suffix automaton
1212
//! and gathering information among the subtree of the link tree.
1313
mod automaton;
14-
mod prefix_dfs;
1514
mod token;
1615

1716
pub use crate::{
1817
automaton::VocabPrefixAutomaton,
19-
prefix_dfs::{TokenSeqInput, TokenSeqTrieNode, dfs_token_seq_trie},
2018
token::{SmallToken, SortedTokenId, SortedTokenRange, TokenId},
2119
};
2220

0 commit comments

Comments (0)