Skip to content

Commit 7aa37c3

Browse files
fix race conditions and add offline tokenizer loading api
1 parent 8c42597 commit 7aa37c3

File tree

8 files changed

+235
-22
lines changed

8 files changed

+235
-22
lines changed

Cargo.lock

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"]
1818
[dependencies]
1919
anyhow = "1.0.98"
2020
base64 = "0.22.1"
21+
fs2 = "0.4.3"
2122
image = "0.25.6"
2223
serde = { version = "1.0.219", features = ["derive"] }
2324
serde_json = { version = "1.0.140", features = ["preserve_order"] }

python/openai_harmony/__init__.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,10 @@
3636
try:
3737
from .openai_harmony import (
3838
HarmonyError as HarmonyError, # expose the actual Rust error directly
39-
)
40-
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore
41-
from .openai_harmony import (
39+
PyHarmonyEncoding as _PyHarmonyEncoding, # type: ignore
4240
PyStreamableParser as _PyStreamableParser, # type: ignore
43-
)
44-
from .openai_harmony import (
4541
load_harmony_encoding as _load_harmony_encoding, # type: ignore
42+
load_harmony_encoding_from_file as _load_harmony_encoding_from_file, # type: ignore
4643
)
4744

4845
except ModuleNotFoundError: # pragma: no cover – raised during type-checking
@@ -690,6 +687,32 @@ def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:
690687
return HarmonyEncoding(inner)
691688

692689

690+
def load_harmony_encoding_from_file(
    name: str,
    vocab_file: str,
    special_tokens: list[tuple[str, int]],
    pattern: str,
    n_ctx: int,
    max_message_tokens: int,
    max_action_length: int,
    expected_hash: str | None = None,
) -> HarmonyEncoding:
    """Build a ``HarmonyEncoding`` from a vocab file stored on local disk.

    Meant for offline usage: environments where network access is restricted,
    or reproducible builds that must not depend on a remote download.

    All arguments are handed, unchanged and in declaration order, to the
    native ``load_harmony_encoding_from_file`` binding; the object it returns
    is wrapped in :class:`HarmonyEncoding`.

    Args:
        name: Name given to the resulting encoding.
        vocab_file: Path of the local vocab file.
        special_tokens: ``(token_text, rank)`` pairs forwarded to the loader.
        pattern: Split pattern forwarded to the loader.
        n_ctx: Forwarded as the encoding's ``n_ctx``.
        max_message_tokens: Forwarded as the encoding's ``max_message_tokens``.
        max_action_length: Forwarded as the encoding's ``max_action_length``.
        expected_hash: Optional hash string forwarded to the loader.
            NOTE(review): presumably a hex digest checked against the file
            contents -- confirm against the native implementation.

    Returns:
        A ``HarmonyEncoding`` wrapping the natively loaded encoding.
    """
    # The binding is called positionally, so the tuple order must match the
    # native function's parameter order.
    native_args = (
        name,
        vocab_file,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
        expected_hash,
    )
    return HarmonyEncoding(_load_harmony_encoding_from_file(*native_args))
714+
715+
693716
# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
694717
# **runtime** the user is expected to pass the *string* names because the Rust
695718
# side only operates on strings anyway.
@@ -718,4 +741,5 @@ def __str__(self) -> str: # noqa: D401
718741
"StreamableParser",
719742
"StreamState",
720743
"HarmonyError",
744+
"load_harmony_encoding_from_file",
721745
]

src/encoding.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,31 @@ impl HarmonyEncoding {
154154
})
155155
.collect()
156156
}
157+
158+
pub fn from_local_file(
159+
name: String,
160+
vocab_file: &std::path::Path,
161+
expected_hash: Option<&str>,
162+
special_tokens: impl IntoIterator<Item = (String, u32)>,
163+
pattern: &str,
164+
n_ctx: usize,
165+
max_message_tokens: usize,
166+
max_action_length: usize,
167+
) -> anyhow::Result<Self> {
168+
use crate::tiktoken_ext::public_encodings::load_encoding_from_file;
169+
let bpe = load_encoding_from_file(vocab_file, expected_hash, special_tokens, pattern)?;
170+
Ok(HarmonyEncoding {
171+
name,
172+
n_ctx,
173+
max_message_tokens,
174+
max_action_length,
175+
tokenizer_name: vocab_file.display().to_string(),
176+
tokenizer: std::sync::Arc::new(bpe),
177+
format_token_mapping: Default::default(),
178+
stop_formatting_tokens: Default::default(),
179+
stop_formatting_tokens_for_assistant_actions: Default::default(),
180+
})
181+
}
157182
}
158183

159184
// Methods for rendering conversations

src/py_module.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,38 @@ fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
396396
}
397397
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;
398398

399+
// Convenience function to load a HarmonyEncoding from a local vocab file for offline
400+
// scenarios or reproducible builds where remote download is not possible.
401+
#[pyfunction(name = "load_harmony_encoding_from_file")]
402+
fn load_harmony_encoding_from_file_py(
403+
py: Python<'_>,
404+
name: &str,
405+
vocab_file: &str,
406+
special_tokens: Vec<(String, u32)>,
407+
pattern: &str,
408+
n_ctx: usize,
409+
max_message_tokens: usize,
410+
max_action_length: usize,
411+
expected_hash: Option<&str>,
412+
) -> PyResult<Py<PyHarmonyEncoding>> {
413+
let encoding = HarmonyEncoding::from_local_file(
414+
name.to_string(),
415+
std::path::Path::new(vocab_file),
416+
expected_hash,
417+
special_tokens,
418+
pattern,
419+
n_ctx,
420+
max_message_tokens,
421+
max_action_length,
422+
)
423+
.map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
424+
Py::new(py, PyHarmonyEncoding { inner: encoding })
425+
}
426+
m.add_function(pyo3::wrap_pyfunction!(
427+
load_harmony_encoding_from_file_py,
428+
m
429+
)?)?;
430+
399431
// Convenience functions to get the tool configs for the browser and python tools.
400432
#[pyfunction]
401433
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {

src/tiktoken_ext/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
mod public_encodings;
1+
pub mod public_encodings;
22
pub use public_encodings::{set_tiktoken_base_url, Encoding};

src/tiktoken_ext/public_encodings.rs

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::{
99
use base64::{prelude::BASE64_STANDARD, Engine as _};
1010

1111
use crate::tiktoken::{CoreBPE, Rank};
12+
use fs2::FileExt;
1213
use sha1::Sha1;
1314
use sha2::{Digest as _, Sha256};
1415

@@ -420,24 +421,35 @@ fn download_or_find_cached_file(
420421
) -> Result<PathBuf, RemoteVocabFileError> {
421422
let cache_dir = resolve_cache_dir()?;
422423
let cache_path = resolve_cache_path(&cache_dir, url);
423-
if cache_path.exists() {
424-
if verify_file_hash(&cache_path, expected_hash)? {
425-
return Ok(cache_path);
426-
}
427-
let _ = std::fs::remove_file(&cache_path);
428-
}
429-
let hash = load_remote_file(url, &cache_path)?;
430-
if let Some(expected_hash) = expected_hash {
431-
if hash != expected_hash {
424+
let lock_path = cache_path.with_extension("lock");
425+
let lock_file = File::create(&lock_path).map_err(|e| {
426+
RemoteVocabFileError::IOError(format!("creating lock file {lock_path:?}"), e)
427+
})?;
428+
lock_file
429+
.lock_exclusive()
430+
.map_err(|e| RemoteVocabFileError::IOError(format!("locking file {lock_path:?}"), e))?;
431+
let result = (|| {
432+
if cache_path.exists() {
433+
if verify_file_hash(&cache_path, expected_hash)? {
434+
return Ok(cache_path);
435+
}
432436
let _ = std::fs::remove_file(&cache_path);
433-
return Err(RemoteVocabFileError::HashMismatch {
434-
file_url: url.to_string(),
435-
expected_hash: expected_hash.to_string(),
436-
computed_hash: hash,
437-
});
438437
}
439-
}
440-
Ok(cache_path)
438+
let hash = load_remote_file(url, &cache_path)?;
439+
if let Some(expected_hash) = expected_hash {
440+
if hash != expected_hash {
441+
let _ = std::fs::remove_file(&cache_path);
442+
return Err(RemoteVocabFileError::HashMismatch {
443+
file_url: url.to_string(),
444+
expected_hash: expected_hash.to_string(),
445+
computed_hash: hash,
446+
});
447+
}
448+
}
449+
Ok(cache_path)
450+
})();
451+
let _ = fs2::FileExt::unlock(&lock_file);
452+
result
441453
}
442454

443455
#[cfg(target_arch = "wasm32")]
@@ -572,4 +584,25 @@ mod tests {
572584
let _ = encoding.load().unwrap();
573585
}
574586
}
587+
588+
/// For every known encoding, loads it from eight threads at the same time and
/// fails if any of those threads panics. Intended to catch problems that only
/// appear when several loaders run concurrently.
#[test]
fn test_parallel_load_encodings() {
    // Number of threads that load the same encoding simultaneously.
    const WORKERS: usize = 8;

    let encodings = Encoding::all();
    for encoding in encodings {
        let name = encoding.name();

        let mut workers = Vec::with_capacity(WORKERS);
        for _ in 0..WORKERS {
            // Each thread needs its own owned copy of the name.
            let owned_name = name.to_string();
            workers.push(std::thread::spawn(move || {
                Encoding::from_name(&owned_name).unwrap().load().unwrap();
            }));
        }

        // Join all workers before moving on to the next encoding.
        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}
575608
}

tests/test_harmony.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
SystemContent,
3636
ToolDescription,
3737
load_harmony_encoding,
38+
load_harmony_encoding_from_file,
3839
)
3940
from pydantic import ValidationError
4041

@@ -949,3 +950,67 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
949950
]
950951

951952
assert parser.messages == expected
953+
954+
955+
def test_load_harmony_encoding_from_file(tmp_path):
    """Load an encoding from an already-cached o200k_base vocab file.

    The test looks for the vocab file in ``$TIKTOKEN_RS_CACHE_DIR`` (falling
    back to ``~/.cache/tiktoken-rs-cache``) under the SHA-1 hex digest of the
    download URL, and is skipped when that file is absent. It then checks
    that encoding followed by decoding a short string yields text starting
    with the original words.
    """
    import hashlib
    import os

    import pytest
    from openai_harmony import load_harmony_encoding_from_file

    url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"

    # An unset or empty environment variable falls back to the default location.
    cache_dir = os.environ.get("TIKTOKEN_RS_CACHE_DIR") or os.path.join(
        os.path.expanduser("~"), ".cache", "tiktoken-rs-cache"
    )
    vocab_file = os.path.join(cache_dir, hashlib.sha1(url.encode()).hexdigest())
    if not os.path.exists(vocab_file):
        pytest.skip("No local vocab file available for offline test")

    special_tokens = [
        ("<|startoftext|>", 199998),
        ("<|endoftext|>", 199999),
        ("<|reserved_200000|>", 200000),
        ("<|reserved_200001|>", 200001),
        ("<|return|>", 200002),
        ("<|constrain|>", 200003),
        ("<|reserved_200004|>", 200004),
        ("<|channel|>", 200005),
        ("<|start|>", 200006),
        ("<|end|>", 200007),
        ("<|message|>", 200008),
        ("<|reserved_200009|>", 200009),
        ("<|reserved_200010|>", 200010),
        ("<|reserved_200011|>", 200011),
        ("<|call|>", 200012),
        ("<|reserved_200013|>", 200013),
    ]

    # Alternatives of the split regex, joined with "|" below.
    pattern_parts = [
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+(?!\\S)",
        "\\s+",
    ]

    encoding = load_harmony_encoding_from_file(
        name="test_local",
        vocab_file=vocab_file,
        special_tokens=special_tokens,
        pattern="|".join(pattern_parts),
        n_ctx=8192,
        max_message_tokens=4096,
        max_action_length=256,
        expected_hash="446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d",
    )

    round_tripped = encoding.decode(encoding.encode("Hello world!"))
    assert round_tripped.startswith("Hello world")

0 commit comments

Comments
 (0)