Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ wasm-binding = ["wasm-bindgen", "serde-wasm-bindgen", "wasm-bindgen-futures"]
[dependencies]
anyhow = "1.0.98"
base64 = "0.22.1"
fs2 = "0.4.3"
image = "0.25.6"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = { version = "1.0.140", features = ["preserve_order"] }
Expand Down
34 changes: 29 additions & 5 deletions python/openai_harmony/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,10 @@
try:
from .openai_harmony import (
HarmonyError as HarmonyError, # expose the actual Rust error directly
)
from .openai_harmony import PyHarmonyEncoding as _PyHarmonyEncoding # type: ignore
from .openai_harmony import (
PyHarmonyEncoding as _PyHarmonyEncoding, # type: ignore
PyStreamableParser as _PyStreamableParser, # type: ignore
)
from .openai_harmony import (
load_harmony_encoding as _load_harmony_encoding, # type: ignore
load_harmony_encoding_from_file as _load_harmony_encoding_from_file, # type: ignore
)

except ModuleNotFoundError: # pragma: no cover – raised during type-checking
Expand Down Expand Up @@ -690,6 +687,32 @@ def load_harmony_encoding(name: str | "HarmonyEncodingName") -> HarmonyEncoding:
return HarmonyEncoding(inner)


def load_harmony_encoding_from_file(
    name: str,
    vocab_file: str,
    special_tokens: list[tuple[str, int]],
    pattern: str,
    n_ctx: int,
    max_message_tokens: int,
    max_action_length: int,
    expected_hash: str | None = None,
) -> HarmonyEncoding:
    """Load a :class:`HarmonyEncoding` from a local vocab file (offline usage).

    Use this when network access is restricted or for reproducible builds
    where you want to avoid remote downloads.

    Args:
        name: Name to assign to the resulting encoding.
        vocab_file: Path to the local tiktoken vocab file.
        special_tokens: ``(token, rank)`` pairs to register as special tokens.
        pattern: Regex pattern used to split text for BPE tokenization.
        n_ctx: Context-window size to record on the encoding.
        max_message_tokens: Maximum number of tokens allowed per message.
        max_action_length: Maximum allowed length of an assistant action.
        expected_hash: Optional hex digest; when given, the vocab file is
            verified against it before loading.

    Returns:
        A ready-to-use :class:`HarmonyEncoding`.

    Raises:
        HarmonyError: If the file cannot be read, fails hash verification,
            or cannot be parsed as a vocab file.
    """
    # Delegate to the Rust binding, then wrap the raw PyO3 object in the
    # high-level Python wrapper so callers get the same API surface as
    # ``load_harmony_encoding``.
    inner: _PyHarmonyEncoding = _load_harmony_encoding_from_file(
        name,
        vocab_file,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
        expected_hash,
    )
    return HarmonyEncoding(inner)


# For *mypy* we expose a minimal stub of the `HarmonyEncodingName` enum. At
# **runtime** the user is expected to pass the *string* names because the Rust
# side only operates on strings anyway.
Expand Down Expand Up @@ -718,4 +741,5 @@ def __str__(self) -> str: # noqa: D401
"StreamableParser",
"StreamState",
"HarmonyError",
"load_harmony_encoding_from_file",
]
25 changes: 25 additions & 0 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,31 @@ impl HarmonyEncoding {
})
.collect()
}

/// Construct a `HarmonyEncoding` from a local tiktoken vocab file.
///
/// Intended for offline scenarios or reproducible builds where remote
/// download of the vocab is not possible or not desired.
///
/// # Errors
///
/// Returns an error if the file cannot be read, fails verification
/// against `expected_hash` (when provided), or does not parse as a
/// valid vocab file.
pub fn from_local_file(
    name: String,
    vocab_file: &std::path::Path,
    expected_hash: Option<&str>,
    special_tokens: impl IntoIterator<Item = (String, u32)>,
    pattern: &str,
    n_ctx: usize,
    max_message_tokens: usize,
    max_action_length: usize,
) -> anyhow::Result<Self> {
    use crate::tiktoken_ext::public_encodings::load_encoding_from_file;
    // Parse the vocab file into a BPE tokenizer, verifying `expected_hash`
    // when one was supplied.
    let bpe = load_encoding_from_file(vocab_file, expected_hash, special_tokens, pattern)?;
    Ok(HarmonyEncoding {
        name,
        n_ctx,
        max_message_tokens,
        max_action_length,
        // Record the file path as the tokenizer name for diagnostics.
        tokenizer_name: vocab_file.display().to_string(),
        tokenizer: std::sync::Arc::new(bpe),
        // NOTE(review): the Harmony format/stop-token mappings are left
        // empty (Default) here — confirm whether callers expect these to
        // be populated the way they are for remotely loaded encodings.
        format_token_mapping: Default::default(),
        stop_formatting_tokens: Default::default(),
        stop_formatting_tokens_for_assistant_actions: Default::default(),
    })
}
}

// Methods for rendering conversations
Expand Down
32 changes: 32 additions & 0 deletions src/py_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,38 @@ fn openai_harmony(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
}
m.add_function(pyo3::wrap_pyfunction!(load_harmony_encoding_py, m)?)?;

// Convenience function to load a HarmonyEncoding from a local vocab file for offline
// scenarios or reproducible builds where remote download is not possible.
// Python binding: load a `HarmonyEncoding` from a local vocab file.
// Mirrors `load_harmony_encoding` but reads the vocab from `vocab_file`
// instead of downloading it, for offline or reproducible-build usage.
#[pyfunction(name = "load_harmony_encoding_from_file")]
fn load_harmony_encoding_from_file_py(
    py: Python<'_>,
    name: &str,
    vocab_file: &str,
    special_tokens: Vec<(String, u32)>,
    pattern: &str,
    n_ctx: usize,
    max_message_tokens: usize,
    max_action_length: usize,
    expected_hash: Option<&str>,
) -> PyResult<Py<PyHarmonyEncoding>> {
    let encoding = HarmonyEncoding::from_local_file(
        name.to_string(),
        std::path::Path::new(vocab_file),
        expected_hash,
        special_tokens,
        pattern,
        n_ctx,
        max_message_tokens,
        max_action_length,
    )
    // Surface any Rust-side failure (I/O, hash mismatch, parse error) to
    // Python as the `HarmonyError` exception type.
    .map_err(|e| PyErr::new::<HarmonyError, _>(e.to_string()))?;
    Py::new(py, PyHarmonyEncoding { inner: encoding })
}
m.add_function(pyo3::wrap_pyfunction!(
load_harmony_encoding_from_file_py,
m
)?)?;

// Convenience functions to get the tool configs for the browser and python tools.
#[pyfunction]
fn get_tool_namespace_config(py: Python<'_>, tool: &str) -> PyResult<PyObject> {
Expand Down
2 changes: 1 addition & 1 deletion src/tiktoken_ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod public_encodings;
pub mod public_encodings;
pub use public_encodings::{set_tiktoken_base_url, Encoding};
65 changes: 49 additions & 16 deletions src/tiktoken_ext/public_encodings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{
use base64::{prelude::BASE64_STANDARD, Engine as _};

use crate::tiktoken::{CoreBPE, Rank};
use fs2::FileExt;
use sha1::Sha1;
use sha2::{Digest as _, Sha256};

Expand Down Expand Up @@ -420,24 +421,35 @@ fn download_or_find_cached_file(
) -> Result<PathBuf, RemoteVocabFileError> {
let cache_dir = resolve_cache_dir()?;
let cache_path = resolve_cache_path(&cache_dir, url);
if cache_path.exists() {
if verify_file_hash(&cache_path, expected_hash)? {
return Ok(cache_path);
}
let _ = std::fs::remove_file(&cache_path);
}
let hash = load_remote_file(url, &cache_path)?;
if let Some(expected_hash) = expected_hash {
if hash != expected_hash {
let lock_path = cache_path.with_extension("lock");
let lock_file = File::create(&lock_path).map_err(|e| {
RemoteVocabFileError::IOError(format!("creating lock file {lock_path:?}"), e)
})?;
lock_file
.lock_exclusive()
.map_err(|e| RemoteVocabFileError::IOError(format!("locking file {lock_path:?}"), e))?;
let result = (|| {
if cache_path.exists() {
if verify_file_hash(&cache_path, expected_hash)? {
return Ok(cache_path);
}
let _ = std::fs::remove_file(&cache_path);
return Err(RemoteVocabFileError::HashMismatch {
file_url: url.to_string(),
expected_hash: expected_hash.to_string(),
computed_hash: hash,
});
}
}
Ok(cache_path)
let hash = load_remote_file(url, &cache_path)?;
if let Some(expected_hash) = expected_hash {
if hash != expected_hash {
let _ = std::fs::remove_file(&cache_path);
return Err(RemoteVocabFileError::HashMismatch {
file_url: url.to_string(),
expected_hash: expected_hash.to_string(),
computed_hash: hash,
});
}
}
Ok(cache_path)
})();
let _ = fs2::FileExt::unlock(&lock_file);
result
}

#[cfg(target_arch = "wasm32")]
Expand Down Expand Up @@ -572,4 +584,25 @@ mod tests {
let _ = encoding.load().unwrap();
}
}

#[test]
fn test_parallel_load_encodings() {
    use std::thread;

    // For every known encoding, spawn several threads that all resolve and
    // load it concurrently; each load must succeed without racing on the
    // shared on-disk cache.
    for encoding in Encoding::all() {
        let encoding_name = encoding.name().to_string();
        let workers: Vec<_> = (0..8)
            .map(|_| {
                let encoding_name = encoding_name.clone();
                thread::spawn(move || {
                    let enc = Encoding::from_name(&encoding_name).unwrap();
                    enc.load().unwrap();
                })
            })
            .collect();
        for worker in workers {
            worker.join().expect("Thread panicked");
        }
    }
}
}
65 changes: 65 additions & 0 deletions tests/test_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
SystemContent,
ToolDescription,
load_harmony_encoding,
load_harmony_encoding_from_file,
)
from pydantic import ValidationError

Expand Down Expand Up @@ -949,3 +950,67 @@ def test_streamable_parser_tool_call_with_constrain_adjacent():
]

assert parser.messages == expected


def test_load_harmony_encoding_from_file(tmp_path):
    """Offline load: build an encoding from a locally cached o200k vocab file.

    Skips when the tiktoken cache does not already contain the vocab file,
    so this test never triggers a network download.
    """
    import hashlib
    import os

    import pytest

    # Locate the cached vocab file the downloader would have written:
    # <cache_dir>/sha1(url), honoring TIKTOKEN_RS_CACHE_DIR when set.
    cache_dir = os.environ.get("TIKTOKEN_RS_CACHE_DIR")
    if not cache_dir:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "tiktoken-rs-cache")
    url = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
    cache_key = hashlib.sha1(url.encode()).hexdigest()
    vocab_file = os.path.join(cache_dir, cache_key)
    if not os.path.exists(vocab_file):
        pytest.skip("No local vocab file available for offline test")

    # Special tokens matching the Harmony o200k layout.
    special_tokens = [
        ("<|startoftext|>", 199998),
        ("<|endoftext|>", 199999),
        ("<|reserved_200000|>", 200000),
        ("<|reserved_200001|>", 200001),
        ("<|return|>", 200002),
        ("<|constrain|>", 200003),
        ("<|reserved_200004|>", 200004),
        ("<|channel|>", 200005),
        ("<|start|>", 200006),
        ("<|end|>", 200007),
        ("<|message|>", 200008),
        ("<|reserved_200009|>", 200009),
        ("<|reserved_200010|>", 200010),
        ("<|reserved_200011|>", 200011),
        ("<|call|>", 200012),
        ("<|reserved_200013|>", 200013),
    ]
    # o200k split pattern (alternation of word/number/punctuation/whitespace rules).
    pattern = "|".join([
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        "\\p{N}{1,3}",
        " ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*",
        "\\s*[\\r\\n]+",
        "\\s+(?!\\S)",
        "\\s+",
    ])
    n_ctx = 8192
    max_message_tokens = 4096
    max_action_length = 256
    # SHA-256 of the o200k_base vocab file; the loader verifies it.
    expected_hash = "446a9538cb6c348e3516120d7c08b09f57c36495e2acfffe59a5bf8b0cfb1a2d"

    # Uses the module-level import of load_harmony_encoding_from_file; the
    # previous function-local re-import was redundant and shadowed it.
    encoding = load_harmony_encoding_from_file(
        name="test_local",
        vocab_file=vocab_file,
        special_tokens=special_tokens,
        pattern=pattern,
        n_ctx=n_ctx,
        max_message_tokens=max_message_tokens,
        max_action_length=max_action_length,
        expected_hash=expected_hash,
    )

    # Basic encode/decode round-trip sanity check.
    text = "Hello world!"
    tokens = encoding.encode(text)
    decoded = encoding.decode(tokens)
    assert decoded.startswith("Hello world")
Loading