Skip to content

Commit d23a666

Browse files
authored
Add together fine tuning config to python client (tensorzero#3105)
* update rust
* update python client
* update error string
* fix naming error
* throw error for credentials if json parsing fails
* add together mock inference provider
* update github workflows config
1 parent dab844c commit d23a666

File tree

14 files changed

+356
-23
lines changed

14 files changed

+356
-23
lines changed

.github/workflows/general.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ jobs:
533533
OPENAI_API_KEY: not_used
534534
FIREWORKS_API_KEY: not_used
535535
FIREWORKS_ACCOUNT_ID: not_used
536+
TOGETHER_API_KEY: not_used
536537
TENSORZERO_USE_MOCK_INFERENCE_PROVIDER: 1
537538
TENSORZERO_SKIP_LARGE_FIXTURES: 1
538539
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}

clients/python/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ use tensorzero_core::{
4040
optimization::{
4141
fireworks_sft::UninitializedFireworksSFTConfig,
4242
gcp_vertex_gemini_sft::UninitializedGCPVertexGeminiSFTConfig,
43-
openai_sft::UninitializedOpenAISFTConfig, OptimizationJobInfoPyClass,
44-
OptimizationJobStatus, UninitializedOptimizerInfo,
43+
openai_sft::UninitializedOpenAISFTConfig, together_sft::UninitializedTogetherSFTConfig,
44+
OptimizationJobInfoPyClass, OptimizationJobStatus, UninitializedOptimizerInfo,
4545
},
4646
variant::{
4747
BestOfNSamplingConfigPyClass, ChainOfThoughtConfigPyClass, ChatCompletionConfigPyClass,
@@ -96,6 +96,7 @@ fn tensorzero(m: &Bound<'_, PyModule>) -> PyResult<()> {
9696
m.add_class::<UninitializedOpenAISFTConfig>()?;
9797
m.add_class::<UninitializedFireworksSFTConfig>()?;
9898
m.add_class::<UninitializedGCPVertexGeminiSFTConfig>()?;
99+
m.add_class::<UninitializedTogetherSFTConfig>()?;
99100
m.add_class::<Datapoint>()?;
100101
m.add_class::<ResolvedInput>()?;
101102
m.add_class::<ResolvedInputMessage>()?;

clients/python/tensorzero/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
ResolvedInput,
2626
ResolvedInputMessage,
2727
StoredInference,
28+
TogetherSFTConfig,
2829
VariantsConfig,
2930
)
3031
from .tensorzero import (
@@ -89,7 +90,7 @@
8990
ChatDatapoint = Datapoint.Chat
9091
JsonDatapoint = Datapoint.Json
9192

92-
OptimizationConfig = t.Union[OpenAISFTConfig, FireworksSFTConfig]
93+
OptimizationConfig = t.Union[OpenAISFTConfig, FireworksSFTConfig, TogetherSFTConfig]
9394
ChatInferenceOutput = t.List[ContentBlock]
9495

9596

@@ -166,6 +167,7 @@
166167
"Thought",
167168
"ThoughtChunk",
168169
"TimeFilter",
170+
"TogetherSFTConfig",
169171
"Tool",
170172
"ToolChoice",
171173
"ToolParams",

clients/python/tensorzero/tensorzero.pyi

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,16 @@ class GCPVertexGeminiSFTConfig:
202202
bucket_path_prefix: Optional[str] = None,
203203
) -> None: ...
204204

205+
@final
206+
class TogetherSFTConfig:
207+
def __init__(
208+
self,
209+
*,
210+
model: str,
211+
credentials: Optional[str] = None,
212+
api_base: Optional[str] = None,
213+
) -> None: ...
214+
205215
@final
206216
class Datapoint:
207217
Chat: Type["Datapoint"]
@@ -1023,6 +1033,7 @@ __all__ = [
10231033
"OptimizationJobInfo",
10241034
"OptimizationJobStatus",
10251035
"RenderedSample",
1036+
"TogetherSFTConfig",
10261037
"StoredInference",
10271038
"ResolvedInput",
10281039
"ResolvedInputMessage",

clients/python/tests/test_optimization.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
OptimizationJobStatus,
1010
RenderedSample,
1111
TensorZeroGateway,
12+
TogetherSFTConfig,
1213
)
1314

1415

@@ -57,6 +58,28 @@ def test_sync_fireworks_sft(
5758
sleep(1)
5859

5960

61+
def test_sync_together_sft(
62+
embedded_sync_client: TensorZeroGateway,
63+
mixed_rendered_samples: List[RenderedSample],
64+
):
65+
optimization_config = TogetherSFTConfig(
66+
model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
67+
api_base="http://localhost:3030/together/",
68+
)
69+
optimization_job_handle = embedded_sync_client.experimental_launch_optimization(
70+
train_samples=mixed_rendered_samples,
71+
val_samples=None,
72+
optimization_config=optimization_config,
73+
)
74+
while True:
75+
job_info = embedded_sync_client.experimental_poll_optimization(
76+
job_handle=optimization_job_handle
77+
)
78+
if job_info.status == OptimizationJobStatus.Completed:
79+
break
80+
sleep(1)
81+
82+
6083
@pytest.mark.asyncio
6184
async def test_async_openai_sft(
6285
embedded_async_client: AsyncTensorZeroGateway,
@@ -105,3 +128,28 @@ async def test_async_fireworks_sft(
105128
if job_info.status == OptimizationJobStatus.Completed:
106129
break
107130
sleep(1)
131+
132+
133+
@pytest.mark.asyncio
134+
async def test_async_together_sft(
135+
embedded_async_client: AsyncTensorZeroGateway,
136+
mixed_rendered_samples: List[RenderedSample],
137+
):
138+
optimization_config = TogetherSFTConfig(
139+
model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
140+
api_base="http://localhost:3030/together/",
141+
)
142+
optimization_job_handle = (
143+
await embedded_async_client.experimental_launch_optimization(
144+
train_samples=mixed_rendered_samples,
145+
val_samples=None,
146+
optimization_config=optimization_config,
147+
)
148+
)
149+
while True:
150+
job_info = await embedded_async_client.experimental_poll_optimization(
151+
job_handle=optimization_job_handle
152+
)
153+
if job_info.status == OptimizationJobStatus.Completed:
154+
break
155+
sleep(1)

tensorzero-core/src/inference/types/pyo3_helpers.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::endpoints::datasets::Datapoint;
1212
use crate::inference::types::{ContentBlockChatOutput, ResolvedInput, ResolvedInputMessageContent};
1313
use crate::optimization::fireworks_sft::UninitializedFireworksSFTConfig;
1414
use crate::optimization::openai_sft::UninitializedOpenAISFTConfig;
15+
use crate::optimization::together_sft::UninitializedTogetherSFTConfig;
1516
use crate::optimization::UninitializedOptimizerConfig;
1617
use crate::stored_inference::{
1718
RenderedSample, SimpleStoredSampleInfo, StoredInference, StoredSample,
@@ -323,9 +324,11 @@ pub fn deserialize_optimization_config(
323324
Ok(UninitializedOptimizerConfig::OpenAISFT(obj.extract()?))
324325
} else if obj.is_instance_of::<UninitializedFireworksSFTConfig>() {
325326
Ok(UninitializedOptimizerConfig::FireworksSFT(obj.extract()?))
327+
} else if obj.is_instance_of::<UninitializedTogetherSFTConfig>() {
328+
Ok(UninitializedOptimizerConfig::TogetherSFT(obj.extract()?))
326329
} else {
327330
Err(PyValueError::new_err(
328-
"Invalid optimization config. Expected OpenAISFTConfig or FireworksSFTConfig",
331+
"Invalid optimization config. Expected OpenAISFTConfig, FireworksSFTConfig, or TogetherSFTConfig",
329332
))
330333
}
331334
}

tensorzero-core/src/optimization/fireworks_sft/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,11 @@ impl UninitializedFireworksSFTConfig {
212212
account_id: String,
213213
api_base: Option<String>,
214214
) -> PyResult<Self> {
215-
let credentials =
216-
credentials.map(|s| serde_json::from_str(&s).unwrap_or(CredentialLocation::Env(s)));
215+
let credentials = credentials
216+
.map(|s| serde_json::from_str(&s))
217+
.transpose()
218+
.map_err(|e| PyErr::new::<PyValueError, _>(format!("Invalid credentials JSON: {e}")))?
219+
.or_else(|| Some(default_api_key_location()));
217220
let api_base = api_base
218221
.map(|s| {
219222
Url::parse(&s)

tensorzero-core/src/optimization/gcp_vertex_gemini_sft/mod.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ use crate::{
2424
stored_inference::RenderedSample,
2525
};
2626

27-
#[cfg(feature = "pyo3")]
28-
use crate::inference::types::pyo3_helpers::tensorzero_core_error;
29-
3027
pub fn gcp_vertex_gemini_base_url(project_id: &str, region: &str) -> Result<Url, url::ParseError> {
3128
let subdomain_prefix = location_subdomain_prefix(region);
3229
Url::parse(&format!(
@@ -95,7 +92,7 @@ impl UninitializedGCPVertexGeminiSFTConfig {
9592
#[new]
9693
#[pyo3(signature = (*, model, bucket_name, project_id, region, learning_rate_multiplier=None, adapter_size=None, n_epochs=None, export_last_checkpoint_only=None, credentials=None, api_base=None, seed=None, service_account=None, kms_key_name=None, tuned_model_display_name=None, bucket_path_prefix=None))]
9794
pub fn new(
98-
py: Python<'_>,
95+
_py: Python<'_>,
9996
model: String,
10097
bucket_name: String,
10198
project_id: String,
@@ -113,13 +110,11 @@ impl UninitializedGCPVertexGeminiSFTConfig {
113110
bucket_path_prefix: Option<String>,
114111
) -> PyResult<Self> {
115112
// Use Deserialize to convert the string to a CredentialLocation
116-
let credentials = match credentials {
117-
Some(s) => match serde_json::from_str(&s) {
118-
Ok(parsed) => Some(parsed),
119-
Err(e) => return Err(tensorzero_core_error(py, &e.to_string())?),
120-
},
121-
None => None,
122-
};
113+
let credentials = credentials
114+
.map(|s| serde_json::from_str(&s))
115+
.transpose()
116+
.map_err(|e| PyErr::new::<PyValueError, _>(format!("Invalid credentials JSON: {e}")))?
117+
.or_else(|| Some(default_api_key_location()));
123118
let api_base = api_base
124119
.map(|s| {
125120
Url::parse(&s)

tensorzero-core/src/optimization/openai_sft/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,11 @@ impl UninitializedOpenAISFTConfig {
9090
suffix: Option<String>,
9191
) -> PyResult<Self> {
9292
// Use Deserialize to convert the string to a CredentialLocation
93-
let credentials =
94-
credentials.map(|s| serde_json::from_str(&s).unwrap_or(CredentialLocation::Env(s)));
93+
let credentials = credentials
94+
.map(|s| serde_json::from_str(&s))
95+
.transpose()
96+
.map_err(|e| PyErr::new::<PyValueError, _>(format!("Invalid credentials JSON: {e}")))?
97+
.or_else(|| Some(default_api_key_location()));
9598
let api_base = api_base
9699
.map(|s| {
97100
Url::parse(&s)

tensorzero-core/src/optimization/together_sft/mod.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#[cfg(feature = "pyo3")]
2+
use pyo3::exceptions::PyValueError;
3+
#[cfg(feature = "pyo3")]
24
use pyo3::prelude::*;
35
use std::borrow::Cow;
46
use std::collections::HashMap;
@@ -66,13 +68,21 @@ impl std::fmt::Display for TogetherSFTJobHandle {
6668
#[cfg_attr(test, derive(ts_rs::TS))]
6769
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
6870
#[cfg_attr(test, ts(export))]
71+
#[cfg_attr(feature = "pyo3", pyclass(str, name = "TogetherSFTConfig"))]
6972
pub struct UninitializedTogetherSFTConfig {
7073
pub model: String,
7174
#[cfg_attr(test, ts(type = "string | null"))]
7275
pub credentials: Option<CredentialLocation>,
7376
pub api_base: Option<Url>,
7477
}
7578

79+
impl std::fmt::Display for UninitializedTogetherSFTConfig {
80+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81+
let json = serde_json::to_string_pretty(self).map_err(|_| std::fmt::Error)?;
82+
write!(f, "{json}")
83+
}
84+
}
85+
7686
#[derive(Debug, Serialize)]
7787
pub struct TogetherSupervisedRow<'a> {
7888
messages: Vec<OpenAIRequestMessage<'a>>,
@@ -120,6 +130,57 @@ impl<'a> TryFrom<&'a RenderedSample> for TogetherSupervisedRow<'a> {
120130
}
121131
}
122132

133+
#[cfg(feature = "pyo3")]
134+
#[pymethods]
135+
impl UninitializedTogetherSFTConfig {
136+
// Python constructor for TogetherSFTConfig; every argument is keyword-only (see the `*` in the signature below).
137+
/// NOTE: This signature currently does not work:
138+
/// print(TogetherSFTConfig.__init__.__text_signature__)
139+
/// prints out signature:
140+
/// ($self, /, *args, **kwargs)
141+
#[new]
142+
#[pyo3(signature = (*, model, credentials=None, api_base=None))]
143+
pub fn new(
144+
model: String,
145+
credentials: Option<String>,
146+
api_base: Option<String>,
147+
) -> PyResult<Self> {
148+
// Use Deserialize to convert the string to a CredentialLocation
149+
let credentials = credentials
150+
.map(|s| serde_json::from_str(&s))
151+
.transpose()
152+
.map_err(|e| PyErr::new::<PyValueError, _>(format!("Invalid credentials JSON: {e}")))?
153+
.or_else(|| Some(default_api_key_location()));
154+
let api_base = api_base
155+
.map(|s| {
156+
Url::parse(&s)
157+
.map_err(|e| PyErr::new::<PyValueError, std::string::String>(e.to_string()))
158+
})
159+
.transpose()?;
160+
Ok(Self {
161+
model,
162+
credentials,
163+
api_base,
164+
})
165+
}
166+
167+
/// Initialize the TogetherSFTConfig. All parameters are optional except for `model`.
168+
///
169+
/// :param model: The model to use for the fine-tuning job.
170+
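/// :param credentials: The credentials to use for the fine-tuning job. The string is parsed as JSON into a credential location, and invalid JSON raises a ValueError. Presumably this means a JSON-encoded string such as "\"env::TOGETHER_API_KEY\"" (TODO: confirm the exact form against the docs). See docs for more details.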
171+
/// :param api_base: The base URL to use for the fine-tuning job. This is primarily used for testing.
172+
#[expect(unused_variables)]
173+
#[pyo3(signature = (*, model, credentials=None, api_base=None))]
174+
fn __init__(
175+
this: Py<Self>,
176+
model: String,
177+
credentials: Option<String>,
178+
api_base: Option<String>,
179+
) -> Py<Self> {
180+
this
181+
}
182+
}
183+
123184
impl UninitializedTogetherSFTConfig {
124185
pub fn load(self) -> Result<TogetherSFTConfig, Error> {
125186
Ok(TogetherSFTConfig {

0 commit comments

Comments
 (0)