Skip to content

Commit 2412c2f

Browse files
nenad1002Prathik Rao
authored andcommitted
Rust bug fixes & changes (#560)
Part 1 of Rust changes (have part 2 but don't have time to test it now). This is mostly improving perf by reducing cloning and fixing some bugs + making code more readable (avoiding early returns).
1 parent ace1376 commit 2412c2f

File tree

9 files changed

+98
-85
lines changed

9 files changed

+98
-85
lines changed

sdk/rust/docs/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ pub struct Model { /* private fields */ }
149149
|--------|-----------|-------------|
150150
| `alias` | `fn alias(&self) -> &str` | Alias shared by all variants. |
151151
| `id` | `fn id(&self) -> &str` | Unique identifier of the selected variant. |
152-
| `variants` | `fn variants(&self) -> &[ModelVariant]` | All variants in this model. |
152+
| `variants` | `fn variants(&self) -> &[Arc<ModelVariant>]` | All variants in this model. |
153153
| `selected_variant` | `fn selected_variant(&self) -> &ModelVariant` | Currently selected variant. |
154154
| `select_variant` | `fn select_variant(&self, id: &str) -> Result<(), FoundryLocalError>` | Select a variant by id. |
155155
| `is_cached` | `async fn is_cached(&self) -> Result<bool, FoundryLocalError>` | Whether the selected variant is cached on disk. |

sdk/rust/src/catalog.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ impl Catalog {
135135
self.update_models().await?;
136136
let s = self.lock_state()?;
137137
s.models_by_alias.get(alias).cloned().ok_or_else(|| {
138-
let available: Vec<&String> = s.models_by_alias.keys().collect();
138+
let available: Vec<&str> = s.models_by_alias.keys().map(|k| k.as_str()).collect();
139139
FoundryLocalError::ModelOperation {
140140
reason: format!("Unknown model alias '{alias}'. Available: {available:?}"),
141141
}
@@ -152,7 +152,7 @@ impl Catalog {
152152
self.update_models().await?;
153153
let s = self.lock_state()?;
154154
s.variants_by_id.get(id).cloned().ok_or_else(|| {
155-
let available: Vec<&String> = s.variants_by_id.keys().collect();
155+
let available: Vec<&str> = s.variants_by_id.keys().map(|k| k.as_str()).collect();
156156
FoundryLocalError::ModelOperation {
157157
reason: format!("Unknown variant id '{id}'. Available: {available:?}"),
158158
}
@@ -216,18 +216,17 @@ impl Catalog {
216216
for info in infos {
217217
let id = info.id.clone();
218218
let alias = info.alias.clone();
219-
let variant = ModelVariant::new(
219+
let variant = Arc::new(ModelVariant::new(
220220
info,
221221
Arc::clone(&self.core),
222222
Arc::clone(&self.model_load_manager),
223223
self.invalidator.clone(),
224-
);
225-
let variant_arc = Arc::new(variant.clone());
226-
id_map.insert(id, variant_arc);
224+
));
225+
id_map.insert(id, Arc::clone(&variant));
227226

228227
alias_map_build
229-
.entry(alias.clone())
230-
.or_insert_with(|| Model::new(alias, Arc::clone(&self.core)))
228+
.entry(alias)
229+
.or_insert_with_key(|a| Model::new(a.clone(), Arc::clone(&self.core)))
231230
.add_variant(variant);
232231
}
233232

sdk/rust/src/configuration.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -183,31 +183,24 @@ impl Configuration {
183183
let mut params = HashMap::new();
184184
params.insert("AppName".into(), app_name);
185185

186-
if let Some(v) = config.app_data_dir {
187-
params.insert("AppDataDir".into(), v);
188-
}
189-
if let Some(v) = config.model_cache_dir {
190-
params.insert("ModelCacheDir".into(), v);
191-
}
192-
if let Some(v) = config.logs_dir {
193-
params.insert("LogsDir".into(), v);
194-
}
195-
if let Some(level) = config.log_level {
196-
params.insert("LogLevel".into(), level.as_core_str().into());
197-
}
198-
if let Some(v) = config.web_service_urls {
199-
params.insert("WebServiceUrls".into(), v);
200-
}
201-
if let Some(v) = config.service_endpoint {
202-
params.insert("WebServiceExternalUrl".into(), v);
203-
}
204-
if let Some(v) = config.library_path {
205-
params.insert("FoundryLocalCorePath".into(), v);
186+
let optional_fields = [
187+
("AppDataDir", config.app_data_dir),
188+
("ModelCacheDir", config.model_cache_dir),
189+
("LogsDir", config.logs_dir),
190+
("LogLevel", config.log_level.map(|l| l.as_core_str().into())),
191+
("WebServiceUrls", config.web_service_urls),
192+
("WebServiceExternalUrl", config.service_endpoint),
193+
("FoundryLocalCorePath", config.library_path),
194+
];
195+
196+
for (key, value) in optional_fields {
197+
if let Some(v) = value {
198+
params.insert(key.into(), v);
199+
}
206200
}
201+
207202
if let Some(extra) = config.additional_settings {
208-
for (k, v) in extra {
209-
params.insert(k, v);
210-
}
203+
params.extend(extra);
211204
}
212205

213206
Ok((Self { params }, config.logger))

sdk/rust/src/detail/core_interop.rs

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -137,25 +137,42 @@ impl<'a> StreamingCallbackState<'a> {
137137

138138
/// Append raw bytes, decode as much valid UTF-8 as possible, and forward
139139
/// complete text to the callback. Any trailing incomplete multi-byte
140-
/// sequence is kept in the buffer for the next call.
140+
/// sequence is kept in the buffer for the next call. Invalid byte
141+
/// sequences are skipped to prevent the buffer from growing unboundedly.
141142
fn push(&mut self, bytes: &[u8]) {
142143
self.buf.extend_from_slice(bytes);
143-
let valid_up_to = match std::str::from_utf8(&self.buf) {
144-
Ok(s) => {
145-
(self.callback)(s);
146-
s.len()
147-
}
148-
Err(e) => {
149-
let n = e.valid_up_to();
150-
if n > 0 {
151-
// SAFETY: `valid_up_to` guarantees this prefix is valid UTF-8.
152-
let valid = unsafe { std::str::from_utf8_unchecked(&self.buf[..n]) };
153-
(self.callback)(valid);
144+
loop {
145+
match std::str::from_utf8(&self.buf) {
146+
Ok(s) => {
147+
if !s.is_empty() {
148+
(self.callback)(s);
149+
}
150+
self.buf.clear();
151+
break;
152+
}
153+
Err(e) => {
154+
let n = e.valid_up_to();
155+
if n > 0 {
156+
// SAFETY: `valid_up_to` guarantees this prefix is valid UTF-8.
157+
let valid = unsafe { std::str::from_utf8_unchecked(&self.buf[..n]) };
158+
(self.callback)(valid);
159+
}
160+
match e.error_len() {
161+
Some(err_len) => {
162+
// Definite invalid sequence — skip past it and
163+
// continue decoding the remainder.
164+
self.buf.drain(..n + err_len);
165+
}
166+
None => {
167+
// Incomplete multi-byte sequence at the end —
168+
// keep it for the next push.
169+
self.buf.drain(..n);
170+
break;
171+
}
172+
}
154173
}
155-
n
156174
}
157-
};
158-
self.buf.drain(..valid_up_to);
175+
}
159176
}
160177

161178
/// Flush any remaining bytes as lossy UTF-8 (called once after the native

sdk/rust/src/detail/model_load_manager.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,27 @@ impl ModelLoadManager {
3434
let encoded_id = urlencoding::encode(model_id);
3535
self.http_get(&format!("{base_url}/models/load/{encoded_id}"))
3636
.await?;
37-
return Ok(());
37+
} else {
38+
let params = json!({ "Params": { "Model": model_id } });
39+
self.core
40+
.execute_command_async("load_model".into(), Some(params))
41+
.await?;
3842
}
39-
let params = json!({ "Params": { "Model": model_id } });
40-
self.core
41-
.execute_command_async("load_model".into(), Some(params))
42-
.await?;
4343
Ok(())
4444
}
4545

4646
/// Unload a previously loaded model.
4747
pub async fn unload(&self, model_id: &str) -> Result<String> {
4848
if let Some(base_url) = &self.external_service_url {
4949
let encoded_id = urlencoding::encode(model_id);
50-
return self
51-
.http_get(&format!("{base_url}/models/unload/{encoded_id}"))
52-
.await;
50+
self.http_get(&format!("{base_url}/models/unload/{encoded_id}"))
51+
.await
52+
} else {
53+
let params = json!({ "Params": { "Model": model_id } });
54+
self.core
55+
.execute_command_async("unload_model".into(), Some(params))
56+
.await
5357
}
54-
let params = json!({ "Params": { "Model": model_id } });
55-
self.core
56-
.execute_command_async("unload_model".into(), Some(params))
57-
.await
5858
}
5959

6060
/// Return the list of currently loaded model identifiers.
@@ -67,11 +67,11 @@ impl ModelLoadManager {
6767
.await?
6868
};
6969

70-
if raw.trim().is_empty() {
71-
return Ok(Vec::new());
72-
}
73-
74-
let ids: Vec<String> = serde_json::from_str(&raw)?;
70+
let ids: Vec<String> = if raw.trim().is_empty() {
71+
Vec::new()
72+
} else {
73+
serde_json::from_str(&raw)?
74+
};
7575
Ok(ids)
7676
}
7777

sdk/rust/src/model.rs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::openai::ChatClient;
1919
pub struct Model {
2020
alias: String,
2121
core: Arc<CoreInterop>,
22-
variants: Vec<ModelVariant>,
22+
variants: Vec<Arc<ModelVariant>>,
2323
selected_index: AtomicUsize,
2424
}
2525

@@ -57,7 +57,7 @@ impl Model {
5757

5858
/// Add a variant. If the new variant is cached and the current selection
5959
/// is not, the new variant becomes the selected one.
60-
pub(crate) fn add_variant(&mut self, variant: ModelVariant) {
60+
pub(crate) fn add_variant(&mut self, variant: Arc<ModelVariant>) {
6161
self.variants.push(variant);
6262
let new_idx = self.variants.len() - 1;
6363
let current = self.selected_index.load(Relaxed);
@@ -70,17 +70,21 @@ impl Model {
7070

7171
/// Select a variant by its unique id.
7272
pub fn select_variant(&self, id: &str) -> Result<()> {
73-
if let Some(pos) = self.variants.iter().position(|v| v.id() == id) {
74-
self.selected_index.store(pos, Relaxed);
75-
return Ok(());
73+
match self.variants.iter().position(|v| v.id() == id) {
74+
Some(pos) => {
75+
self.selected_index.store(pos, Relaxed);
76+
Ok(())
77+
}
78+
None => {
79+
let available: Vec<&str> = self.variants.iter().map(|v| v.id()).collect();
80+
Err(FoundryLocalError::ModelOperation {
81+
reason: format!(
82+
"Variant '{id}' not found for model '{}'. Available: {available:?}",
83+
self.alias
84+
),
85+
})
86+
}
7687
}
77-
let available: Vec<String> = self.variants.iter().map(|v| v.id().to_string()).collect();
78-
Err(FoundryLocalError::ModelOperation {
79-
reason: format!(
80-
"Variant '{id}' not found for model '{}'. Available: {available:?}",
81-
self.alias
82-
),
83-
})
8488
}
8589

8690
/// Returns a reference to the currently selected variant.
@@ -89,7 +93,7 @@ impl Model {
8993
}
9094

9195
/// Returns all variants that belong to this model.
92-
pub fn variants(&self) -> &[ModelVariant] {
96+
pub fn variants(&self) -> &[Arc<ModelVariant>] {
9397
&self.variants
9498
}
9599

@@ -169,11 +173,11 @@ impl Model {
169173

170174
/// Create a [`ChatClient`] bound to the selected variant.
171175
pub fn create_chat_client(&self) -> ChatClient {
172-
ChatClient::new(self.id().to_string(), Arc::clone(&self.core))
176+
ChatClient::new(self.id(), Arc::clone(&self.core))
173177
}
174178

175179
/// Create an [`AudioClient`] bound to the selected variant.
176180
pub fn create_audio_client(&self) -> AudioClient {
177-
AudioClient::new(self.id().to_string(), Arc::clone(&self.core))
181+
AudioClient::new(self.id(), Arc::clone(&self.core))
178182
}
179183
}

sdk/rust/src/model_variant.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,11 @@ impl ModelVariant {
143143

144144
/// Create a [`ChatClient`] bound to this variant.
145145
pub fn create_chat_client(&self) -> ChatClient {
146-
ChatClient::new(self.info.id.clone(), Arc::clone(&self.core))
146+
ChatClient::new(&self.info.id, Arc::clone(&self.core))
147147
}
148148

149149
/// Create an [`AudioClient`] bound to this variant.
150150
pub fn create_audio_client(&self) -> AudioClient {
151-
AudioClient::new(self.info.id.clone(), Arc::clone(&self.core))
151+
AudioClient::new(&self.info.id, Arc::clone(&self.core))
152152
}
153153
}

sdk/rust/src/openai/audio_client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ pub struct AudioClient {
116116
}
117117

118118
impl AudioClient {
119-
pub(crate) fn new(model_id: String, core: Arc<CoreInterop>) -> Self {
119+
pub(crate) fn new(model_id: &str, core: Arc<CoreInterop>) -> Self {
120120
Self {
121-
model_id,
121+
model_id: model_id.to_owned(),
122122
core,
123123
settings: AudioClientSettings::default(),
124124
}

sdk/rust/src/openai/chat_client.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ pub struct ChatClient {
132132
}
133133

134134
impl ChatClient {
135-
pub(crate) fn new(model_id: String, core: Arc<CoreInterop>) -> Self {
135+
pub(crate) fn new(model_id: &str, core: Arc<CoreInterop>) -> Self {
136136
Self {
137-
model_id,
137+
model_id: model_id.to_owned(),
138138
core,
139139
settings: ChatClientSettings::default(),
140140
}

0 commit comments

Comments
 (0)