Skip to content

Commit c5b9b0a

Browse files
authored
Merge pull request #26 from jordan-wu-97/jordan/fix-function-call-atomic-bool
fix: make `HarmonyEncoding` usable concurrently
2 parents e373981 + adce02f commit c5b9b0a

File tree

6 files changed

+130
-34
lines changed

6 files changed

+130
-34
lines changed

python/openai_harmony/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,10 @@ class RenderConversationConfig(BaseModel):
425425
auto_drop_analysis: bool = True
426426

427427

428+
class RenderOptions(BaseModel):
429+
conversation_has_function_tools: bool = False
430+
431+
428432
class HarmonyEncoding:
429433
"""High-level wrapper around the Rust ``PyHarmonyEncoding`` class."""
430434

@@ -498,9 +502,20 @@ def render_conversation_for_training(
498502
config=config_dict,
499503
)
500504

501-
def render(self, message: Message) -> List[int]:
505+
def render(
506+
self, message: Message, render_options: Optional[RenderOptions] = None
507+
) -> List[int]:
502508
"""Render a single message into tokens."""
503-
return self._inner.render(message_json=message.to_json())
509+
if render_options is None:
510+
render_options_dict = {"conversation_has_function_tools": False}
511+
else:
512+
render_options_dict = {
513+
"conversation_has_function_tools": render_options.conversation_has_function_tools
514+
}
515+
516+
return self._inner.render(
517+
message_json=message.to_json(), render_options=render_options_dict
518+
)
504519

505520
# -- Parsing -------------------------------------------------------
506521

src/encoding.rs

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@ use crate::{
55
use anyhow::Context as _;
66
use std::{
77
collections::{HashMap, HashSet},
8-
sync::{
9-
atomic::{AtomicBool, Ordering},
10-
Arc,
11-
},
8+
sync::Arc,
129
vec,
1310
};
1411

@@ -92,7 +89,6 @@ pub struct HarmonyEncoding {
9289
pub(crate) format_token_mapping: HashMap<FormattingToken, String>,
9390
pub(crate) stop_formatting_tokens: HashSet<FormattingToken>,
9491
pub(crate) stop_formatting_tokens_for_assistant_actions: HashSet<FormattingToken>,
95-
pub(crate) conversation_has_function_tools: Arc<AtomicBool>,
9692
}
9793

9894
impl std::fmt::Debug for HarmonyEncoding {
@@ -191,8 +187,9 @@ impl HarmonyEncoding {
191187
}
192188
})
193189
});
194-
self.conversation_has_function_tools
195-
.store(has_function_tools, Ordering::Relaxed);
190+
let render_options = RenderOptions {
191+
conversation_has_function_tools: has_function_tools,
192+
};
196193
let last_assistant_is_final = messages
197194
.iter()
198195
.rev()
@@ -217,9 +214,7 @@ impl HarmonyEncoding {
217214
&& first_final_idx.is_some_and(|first| *idx < first)
218215
&& msg.channel.as_deref() == Some("analysis"))
219216
})
220-
.try_for_each(|(_, msg)| self.render_into(msg, into));
221-
self.conversation_has_function_tools
222-
.store(false, Ordering::Relaxed);
217+
.try_for_each(|(_, msg)| self.render_into(msg, into, Some(&render_options)));
223218
result?;
224219
Ok(())
225220
}
@@ -305,18 +300,27 @@ impl HarmonyEncoding {
305300
}
306301

307302
/// Render a single message into tokens.
308-
pub fn render(&self, message: &Message) -> anyhow::Result<Vec<Rank>> {
303+
pub fn render(
304+
&self,
305+
message: &Message,
306+
render_options: Option<&RenderOptions>,
307+
) -> anyhow::Result<Vec<Rank>> {
309308
let mut out = vec![];
310-
Render::<Message>::render(self, message, &mut out)?;
309+
Render::<Message>::render(self, message, &mut out, render_options)?;
311310
Ok(out)
312311
}
313312

314313
/// Render a single message into the provided buffer.
315-
pub fn render_into<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
314+
pub fn render_into<B>(
315+
&self,
316+
message: &Message,
317+
into: &mut B,
318+
render_options: Option<&RenderOptions>,
319+
) -> anyhow::Result<()>
316320
where
317321
B: Extend<Rank>,
318322
{
319-
Render::<Message>::render(self, message, into)
323+
Render::<Message>::render(self, message, into, render_options)
320324
}
321325
}
322326

@@ -772,14 +776,29 @@ impl HarmonyEncoding {
772776
}
773777
}
774778

779+
#[derive(Clone, Copy, Debug, Default)]
780+
pub struct RenderOptions {
781+
pub conversation_has_function_tools: bool,
782+
}
783+
775784
trait Render<T: ?Sized> {
776-
fn render<B>(&self, item: &T, into: &mut B) -> anyhow::Result<()>
785+
fn render<B>(
786+
&self,
787+
item: &T,
788+
into: &mut B,
789+
render_options: Option<&RenderOptions>,
790+
) -> anyhow::Result<()>
777791
where
778792
B: Extend<Rank>;
779793
}
780794

781795
impl Render<Message> for HarmonyEncoding {
782-
fn render<B>(&self, message: &Message, into: &mut B) -> anyhow::Result<()>
796+
fn render<B>(
797+
&self,
798+
message: &Message,
799+
into: &mut B,
800+
render_options: Option<&RenderOptions>,
801+
) -> anyhow::Result<()>
783802
where
784803
B: Extend<Rank>,
785804
{
@@ -836,7 +855,7 @@ impl Render<Message> for HarmonyEncoding {
836855
message.author.role
837856
);
838857
}
839-
Render::<Content>::render(self, content, into)?;
858+
Render::<Content>::render(self, content, into, render_options)?;
840859
}
841860

842861
// If there is a tool call we should render a tool call token
@@ -851,23 +870,35 @@ impl Render<Message> for HarmonyEncoding {
851870

852871
// Dispatch Content variants to their specific Render implementations
853872
impl Render<Content> for HarmonyEncoding {
854-
fn render<B>(&self, content: &Content, into: &mut B) -> anyhow::Result<()>
873+
fn render<B>(
874+
&self,
875+
content: &Content,
876+
into: &mut B,
877+
render_options: Option<&RenderOptions>,
878+
) -> anyhow::Result<()>
855879
where
856880
B: Extend<Rank>,
857881
{
858882
match content {
859-
Content::Text(text) => Render::<TextContent>::render(self, text, into),
860-
Content::SystemContent(sys) => Render::<SystemContent>::render(self, sys, into),
883+
Content::Text(text) => Render::<TextContent>::render(self, text, into, render_options),
884+
Content::SystemContent(sys) => {
885+
Render::<SystemContent>::render(self, sys, into, render_options)
886+
}
861887
Content::DeveloperContent(dev) => {
862-
Render::<crate::chat::DeveloperContent>::render(self, dev, into)
888+
Render::<crate::chat::DeveloperContent>::render(self, dev, into, render_options)
863889
}
864890
}
865891
}
866892
}
867893

868894
// Render plain text content
869895
impl Render<TextContent> for HarmonyEncoding {
870-
fn render<B>(&self, text: &TextContent, into: &mut B) -> anyhow::Result<()>
896+
fn render<B>(
897+
&self,
898+
text: &TextContent,
899+
into: &mut B,
900+
_render_options: Option<&RenderOptions>,
901+
) -> anyhow::Result<()>
871902
where
872903
B: Extend<Rank>,
873904
{
@@ -877,7 +908,12 @@ impl Render<TextContent> for HarmonyEncoding {
877908

878909
// Render system-specific content (model identity, instructions, effort)
879910
impl Render<SystemContent> for HarmonyEncoding {
880-
fn render<B>(&self, sys: &SystemContent, into: &mut B) -> anyhow::Result<()>
911+
fn render<B>(
912+
&self,
913+
sys: &SystemContent,
914+
into: &mut B,
915+
render_options: Option<&RenderOptions>,
916+
) -> anyhow::Result<()>
881917
where
882918
B: Extend<Rank>,
883919
{
@@ -923,7 +959,7 @@ impl Render<SystemContent> for HarmonyEncoding {
923959
if channel_config.channel_required {
924960
channels_header.push_str(" Channel must be included for every message.");
925961
}
926-
if self.conversation_has_function_tools.load(Ordering::Relaxed) {
962+
if render_options.is_some_and(|o| o.conversation_has_function_tools) {
927963
channels_header.push('\n');
928964
channels_header.push_str(
929965
"Calls to these tools must go to the commentary channel: 'functions'.",
@@ -940,7 +976,12 @@ impl Render<SystemContent> for HarmonyEncoding {
940976

941977
// Render developer-specific content (instructions, tools)
942978
impl Render<crate::chat::DeveloperContent> for HarmonyEncoding {
943-
fn render<B>(&self, dev: &crate::chat::DeveloperContent, into: &mut B) -> anyhow::Result<()>
979+
fn render<B>(
980+
&self,
981+
dev: &crate::chat::DeveloperContent,
982+
into: &mut B,
983+
_render_options: Option<&RenderOptions>,
984+
) -> anyhow::Result<()>
944985
where
945986
B: Extend<Rank>,
946987
{

src/py_module.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,29 @@ impl PyHarmonyEncoding {
178178
}
179179

180180
/// Render a single message into tokens.
181-
fn render(&self, message_json: &str) -> PyResult<Vec<u32>> {
181+
fn render(
182+
&self,
183+
message_json: &str,
184+
render_options: Option<Bound<'_, PyDict>>,
185+
) -> PyResult<Vec<u32>> {
182186
let message: crate::chat::Message = serde_json::from_str(message_json).map_err(|e| {
183187
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("invalid message JSON: {e}"))
184188
})?;
185189

190+
let rust_options = if let Some(options_dict) = render_options {
191+
let conversation_has_function_tools = options_dict
192+
.get_item("conversation_has_function_tools")?
193+
.and_then(|v| v.extract().ok())
194+
.unwrap_or(false);
195+
Some(crate::encoding::RenderOptions {
196+
conversation_has_function_tools,
197+
})
198+
} else {
199+
None
200+
};
201+
186202
self.inner
187-
.render(&message)
203+
.render(&message, rust_options.as_ref())
188204
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
189205
}
190206

src/registry.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{
22
collections::{HashMap, HashSet},
3-
sync::{atomic::AtomicBool, Arc},
3+
sync::Arc,
44
};
55

66
use crate::{
@@ -76,7 +76,6 @@ pub fn load_harmony_encoding(name: HarmonyEncodingName) -> anyhow::Result<Harmon
7676
FormattingToken::EndMessageDoneSampling,
7777
FormattingToken::EndMessageAssistantToTool,
7878
]),
79-
conversation_has_function_tools: Arc::new(AtomicBool::new(false)),
8079
})
8180
}
8281
}

src/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ fn test_render_and_render_conversation_roundtrip() {
525525
let encoding = load_harmony_encoding(HarmonyEncodingName::HarmonyGptOss).unwrap();
526526
let msg = Message::from_role_and_content(Role::User, "Hello");
527527
let convo = Conversation::from_messages([msg.clone()]);
528-
let tokens_msg = encoding.render(&msg).unwrap();
528+
let tokens_msg = encoding.render(&msg, None).unwrap();
529529
let tokens_convo = encoding.render_conversation(&convo, None).unwrap();
530530
assert_eq!(tokens_msg, tokens_convo);
531531
let tokens_completion = encoding

src/wasm_module.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ extern "C" {
1818

1919
#[wasm_bindgen(typescript_type = "RenderConversationConfig")]
2020
pub type JsRenderConversationConfig;
21+
22+
#[wasm_bindgen(typescript_type = "RenderOptions")]
23+
pub type JsRenderOptions;
2124
}
2225

2326
#[wasm_bindgen(typescript_custom_section)]
@@ -127,12 +130,34 @@ impl JsHarmonyEncoding {
127130
}
128131

129132
#[wasm_bindgen]
130-
pub fn render(&self, message: JsMessage) -> Result<Vec<u32>, JsValue> {
133+
pub fn render(
134+
&self,
135+
message: JsMessage,
136+
render_options: JsRenderOptions,
137+
) -> Result<Vec<u32>, JsValue> {
131138
let message: JsValue = message.into();
132139
let message: crate::chat::Message = serde_wasm_bindgen::from_value(message)
133140
.map_err(|e| JsValue::from_str(&format!("invalid message JSON: {e}")))?;
141+
142+
#[derive(Deserialize)]
143+
struct RenderOptions {
144+
conversation_has_function_tools: Option<bool>,
145+
}
146+
let render_options: JsValue = render_options.into();
147+
let rust_options = if render_options.is_undefined() || render_options.is_null() {
148+
None
149+
} else {
150+
let cfg: RenderOptions = serde_wasm_bindgen::from_value(render_options)
151+
.map_err(|e| JsValue::from_str(&format!("invalid render options: {e}")))?;
152+
Some(crate::encoding::RenderOptions {
153+
conversation_has_function_tools: cfg
154+
.conversation_has_function_tools
155+
.unwrap_or(false),
156+
})
157+
};
158+
134159
self.inner
135-
.render(&message)
160+
.render(&message, rust_options.as_ref())
136161
.map_err(|e| JsValue::from_str(&e.to_string()))
137162
}
138163

0 commit comments

Comments
 (0)