|
1 | 1 | #[cfg(feature = "pyo3")] |
| 2 | +use pyo3::exceptions::PyValueError; |
| 3 | +#[cfg(feature = "pyo3")] |
2 | 4 | use pyo3::prelude::*; |
3 | 5 | use std::borrow::Cow; |
4 | 6 | use std::collections::HashMap; |
@@ -66,13 +68,21 @@ impl std::fmt::Display for TogetherSFTJobHandle { |
66 | 68 | #[cfg_attr(test, derive(ts_rs::TS))] |
67 | 69 | #[derive(Clone, Debug, Default, Deserialize, Serialize)] |
68 | 70 | #[cfg_attr(test, ts(export))] |
| 71 | +#[cfg_attr(feature = "pyo3", pyclass(str, name = "TogetherSFTConfig"))] |
69 | 72 | pub struct UninitializedTogetherSFTConfig { |
70 | 73 | pub model: String, |
71 | 74 | #[cfg_attr(test, ts(type = "string | null"))] |
72 | 75 | pub credentials: Option<CredentialLocation>, |
73 | 76 | pub api_base: Option<Url>, |
74 | 77 | } |
75 | 78 |
|
| 79 | +impl std::fmt::Display for UninitializedTogetherSFTConfig { |
| 80 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 81 | + let json = serde_json::to_string_pretty(self).map_err(|_| std::fmt::Error)?; |
| 82 | + write!(f, "{json}") |
| 83 | + } |
| 84 | +} |
| 85 | + |
76 | 86 | #[derive(Debug, Serialize)] |
77 | 87 | pub struct TogetherSupervisedRow<'a> { |
78 | 88 | messages: Vec<OpenAIRequestMessage<'a>>, |
@@ -120,6 +130,57 @@ impl<'a> TryFrom<&'a RenderedSample> for TogetherSupervisedRow<'a> { |
120 | 130 | } |
121 | 131 | } |
122 | 132 |
|
| 133 | +#[cfg(feature = "pyo3")] |
| 134 | +#[pymethods] |
| 135 | +impl UninitializedTogetherSFTConfig { |
| 136 | + // We allow too many arguments since it is a Python constructor |
| 137 | + /// NOTE: This signature currently does not work: |
| 138 | + /// print(TogetherSFTConfig.__init__.__text_signature__) |
| 139 | + /// prints out signature: |
| 140 | + /// ($self, /, *args, **kwargs) |
| 141 | + #[new] |
| 142 | + #[pyo3(signature = (*, model, credentials=None, api_base=None))] |
| 143 | + pub fn new( |
| 144 | + model: String, |
| 145 | + credentials: Option<String>, |
| 146 | + api_base: Option<String>, |
| 147 | + ) -> PyResult<Self> { |
| 148 | + // Use Deserialize to convert the string to a CredentialLocation |
| 149 | + let credentials = credentials |
| 150 | + .map(|s| serde_json::from_str(&s)) |
| 151 | + .transpose() |
| 152 | + .map_err(|e| PyErr::new::<PyValueError, _>(format!("Invalid credentials JSON: {e}")))? |
| 153 | + .or_else(|| Some(default_api_key_location())); |
| 154 | + let api_base = api_base |
| 155 | + .map(|s| { |
| 156 | + Url::parse(&s) |
| 157 | + .map_err(|e| PyErr::new::<PyValueError, std::string::String>(e.to_string())) |
| 158 | + }) |
| 159 | + .transpose()?; |
| 160 | + Ok(Self { |
| 161 | + model, |
| 162 | + credentials, |
| 163 | + api_base, |
| 164 | + }) |
| 165 | + } |
| 166 | + |
| 167 | + /// Initialize the TogetherSFTConfig. All parameters are optional except for `model`. |
| 168 | + /// |
| 169 | + /// :param model: The model to use for the fine-tuning job. |
| 170 | + /// :param credentials: The credentials to use for the fine-tuning job. This should be a string like "env::OPENAI_API_KEY". See docs for more details. |
| 171 | + /// :param api_base: The base URL to use for the fine-tuning job. This is primarily used for testing. |
| 172 | + #[expect(unused_variables)] |
| 173 | + #[pyo3(signature = (*, model, credentials=None, api_base=None))] |
| 174 | + fn __init__( |
| 175 | + this: Py<Self>, |
| 176 | + model: String, |
| 177 | + credentials: Option<String>, |
| 178 | + api_base: Option<String>, |
| 179 | + ) -> Py<Self> { |
| 180 | + this |
| 181 | + } |
| 182 | +} |
| 183 | + |
123 | 184 | impl UninitializedTogetherSFTConfig { |
124 | 185 | pub fn load(self) -> Result<TogetherSFTConfig, Error> { |
125 | 186 | Ok(TogetherSFTConfig { |
|
0 commit comments