Skip to content

Commit 79b09d2

Browse files
authored
Merge pull request #203 from firstbatchxyz/erhant/model-changes
feat: use ollama models
2 parents 3ab82cd + 4d82772 commit 79b09d2

File tree

20 files changed

+193
-272
lines changed

20 files changed

+193
-272
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ default-members = ["compute"]
77

88
[workspace.package]
99
edition = "2021"
10-
version = "0.6.5"
10+
version = "0.6.6"
1111
license = "Apache-2.0"
1212
readme = "README.md"
1313

compute/src/config.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl DriaComputeNodeConfig {
6363
}
6464
}
6565
Err(err) => {
66-
log::error!("No secret key provided: {}", err);
66+
log::error!("No secret key provided: {err}");
6767
panic!("Please provide a secret key.");
6868
}
6969
};
@@ -81,11 +81,11 @@ impl DriaComputeNodeConfig {
8181

8282
// print address
8383
let address = hex::encode(public_key_to_address(&public_key));
84-
log::info!("Node Address: 0x{}", address);
84+
log::info!("Node Address: 0x{address}");
8585

8686
// do this here to log the peer id at start
8787
let peer_id = secret_to_keypair(&secret_key).public().to_peer_id();
88-
log::info!("Node PeerID: {}", peer_id);
88+
log::info!("Node PeerID: {peer_id}");
8989

9090
// parse listen address
9191
let p2p_listen_addr_str = env::var("DKN_P2P_LISTEN_ADDR")

compute/src/main.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ async fn main() -> Result<()> {
3636

3737
// log about env usage
3838
match dotenv_result {
39-
Ok(_) => log::info!("Loaded environment file from {}", env_path),
40-
Err(e) => log::warn!("Could not load environment file from {}: {}", env_path, e),
39+
Ok(_) => log::info!("Loaded environment file from {env_path}"),
40+
Err(err) => log::warn!("Could not load environment file from {env_path}: {err}"),
4141
}
4242

4343
// task tracker for multiple threads
@@ -52,14 +52,14 @@ async fn main() -> Result<()> {
5252
env::var("DKN_EXIT_TIMEOUT").map(|s| s.to_string().parse::<u64>())
5353
{
5454
// the timeout is done for profiling only, and should not be used in production
55-
log::warn!("Waiting for {} seconds before exiting.", duration_secs);
55+
log::warn!("Waiting for {duration_secs} seconds before exiting.");
5656
tokio::time::sleep(tokio::time::Duration::from_secs(duration_secs)).await;
5757

5858
log::warn!("Exiting due to DKN_EXIT_TIMEOUT.");
5959
cancellation_token.cancel();
6060
} else if let Err(err) = wait_for_termination(cancellation_token.clone()).await {
6161
// if there is no timeout, we wait for termination signals here
62-
log::error!("Error waiting for termination: {:?}", err);
62+
log::error!("Error waiting for termination: {err:?}");
6363
log::error!("Cancelling due to unexpected error.");
6464
cancellation_token.cancel();
6565
};
@@ -104,7 +104,7 @@ async fn main() -> Result<()> {
104104
config.executors.get_model_names().join(", "),
105105
model_perf
106106
.iter()
107-
.map(|(model, perf)| format!("{}: {}", model, perf))
107+
.map(|(model, perf)| format!("{model}: {perf}"))
108108
.collect::<Vec<_>>()
109109
.join("\n")
110110
);
@@ -124,10 +124,7 @@ async fn main() -> Result<()> {
124124
batch_size <= TaskWorker::MAX_BATCH_SIZE,
125125
"batch size too large"
126126
);
127-
log::info!(
128-
"Spawning batch executor worker thread. (batch size {})",
129-
batch_size
130-
);
127+
log::info!("Spawning batch executor worker thread. (batch size {batch_size})");
131128
task_tracker.spawn(async move { worker_batch.run_batch(batch_size).await });
132129
}
133130

compute/src/node/core.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ impl DriaComputeNode {
5353
// a task is completed by the worker & should be responded to the requesting peer
5454
task_response_msg_opt = self.task_output_rx.recv() => {
5555
if let Some(task_response_msg) = task_response_msg_opt {
56-
if let Err(e) = self.send_task_output(task_response_msg).await {
57-
log::error!("Error responding to task: {:?}", e);
56+
if let Err(err) = self.send_task_output(task_response_msg).await {
57+
log::error!("Error responding to task: {err:?}");
5858
}
5959
} else {
6060
log::error!("task_output_rx channel closed unexpectedly, we still have {} batch and {} single tasks.", self.pending_tasks_batch.len(), self.pending_tasks_single.len());
@@ -117,8 +117,8 @@ impl DriaComputeNode {
117117
self.handle_diagnostic_refresh().await;
118118

119119
// shutdown channels
120-
if let Err(e) = self.shutdown().await {
121-
log::error!("Could not shutdown the node gracefully: {:?}", e);
120+
if let Err(err) = self.shutdown().await {
121+
log::error!("Could not shutdown the node gracefully: {err:?}");
122122
}
123123
}
124124

compute/src/node/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ impl DriaComputeNode {
8181

8282
// dial the RPC node
8383
let dria_rpc = if let Some(addr) = config.initial_rpc_addr.take() {
84-
log::info!("Using initial RPC address: {}", addr);
84+
log::info!("Using initial RPC address: {addr}");
8585
DriaRPC::new(addr, config.network).expect("could not get RPC to connect to")
8686
} else {
8787
DriaRPC::new_for_network(config.network, &config.version)
@@ -92,7 +92,7 @@ impl DriaComputeNode {
9292
// we are using the major.minor version as the P2P version
9393
// so that patch versions do not interfere with the protocol
9494
let protocol = DriaP2PProtocol::new_major_minor(config.network.protocol_name());
95-
log::info!("Using identity: {}", protocol);
95+
log::info!("Using identity: {protocol}");
9696

9797
// create p2p client
9898
let (p2p_client, p2p_commander, request_rx) = DriaP2PClient::new(

compute/src/node/reqres.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,24 @@ impl DriaComputeNode {
3030
request_id,
3131
channel,
3232
} => {
33-
log::debug!("Received a request ({}) from {}", request_id, peer_id);
33+
log::debug!("Received a request ({request_id}) from {peer_id}");
3434

3535
// ensure that message is from the known RPCs
3636
if self.dria_rpc.peer_id != peer_id {
37-
log::warn!("Received request from unauthorized source: {}", peer_id);
37+
log::warn!("Received request from unauthorized source: {peer_id}");
3838
log::debug!("Allowed source: {}", self.dria_rpc.peer_id);
39-
} else if let Err(e) = self.handle_request(peer_id, &request, channel).await {
40-
log::error!("Error handling request: {:?}", e);
39+
} else if let Err(err) = self.handle_request(peer_id, &request, channel).await {
40+
log::error!("Error handling request: {err:?}");
4141
}
4242
}
4343

4444
DriaReqResMessage::Response {
4545
response,
4646
request_id,
4747
} => {
48-
log::debug!("Received a response ({}) from {}", request_id, peer_id);
49-
if let Err(e) = self.handle_response(peer_id, request_id, response).await {
50-
log::error!("Error handling response: {:?}", e);
48+
log::debug!("Received a response ({request_id}) from {peer_id}");
49+
if let Err(err) = self.handle_response(peer_id, request_id, response).await {
50+
log::error!("Error handling response: {err:?}");
5151
}
5252
}
5353
};
@@ -65,7 +65,7 @@ impl DriaComputeNode {
6565
data: Vec<u8>,
6666
) -> Result<()> {
6767
if peer_id != self.dria_rpc.peer_id {
68-
log::warn!("Received response from unauthorized source: {}", peer_id);
68+
log::warn!("Received response from unauthorized source: {peer_id}");
6969
log::debug!("Allowed source: {}", self.dria_rpc.peer_id);
7070
}
7171

@@ -126,7 +126,7 @@ impl DriaComputeNode {
126126

127127
let (task_input, task_metadata) =
128128
TaskResponder::parse_task_request(self, &task_request, channel).await?;
129-
if let Err(e) = match task_input.task.is_batchable() {
129+
if let Err(err) = match task_input.task.is_batchable() {
130130
// this is a batchable task, send it to batch worker
131131
// and keep track of the task id in pending tasks
132132
true => match self.task_request_batch_tx {
@@ -149,7 +149,7 @@ impl DriaComputeNode {
149149
None => eyre::bail!("Single task received but no worker available."),
150150
},
151151
} {
152-
log::error!("Could not send task to worker: {:?}", e);
152+
log::error!("Could not send task to worker: {err:?}");
153153
};
154154

155155
Ok(())

compute/src/reqres/task.rs

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -175,63 +175,63 @@ fn map_prompt_error_to_task_error(provider: ModelProvider, err: PromptError) ->
175175
}
176176

177177
match provider {
178-
ModelProvider::Gemini => {
179-
/// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
180-
#[derive(Clone, serde::Deserialize)]
181-
pub struct GeminiError {
182-
code: u32,
183-
message: String,
184-
status: String,
185-
}
178+
// ModelProvider::Gemini => {
179+
// /// Gemini API [error object](https://github.com/googleapis/go-genai/blob/main/api_client.go#L273).
180+
// #[derive(Clone, serde::Deserialize)]
181+
// pub struct GeminiError {
182+
// code: u32,
183+
// message: String,
184+
// status: String,
185+
// }
186186

187-
serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
188-
|ErrorObject {
189-
error: gemini_error,
190-
}| TaskError::ProviderError {
191-
code: format!("{} ({})", gemini_error.code, gemini_error.status),
192-
message: gemini_error.message,
193-
provider: provider.to_string(),
194-
},
195-
)
196-
}
197-
ModelProvider::OpenAI => {
198-
/// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
199-
#[derive(Clone, serde::Deserialize)]
200-
pub struct OpenAIError {
201-
code: String,
202-
message: String,
203-
}
187+
// serde_json::from_str::<ErrorObject<GeminiError>>(err_inner).map(
188+
// |ErrorObject {
189+
// error: gemini_error,
190+
// }| TaskError::ProviderError {
191+
// code: format!("{} ({})", gemini_error.code, gemini_error.status),
192+
// message: gemini_error.message,
193+
// provider: provider.to_string(),
194+
// },
195+
// )
196+
// }
197+
// ModelProvider::OpenAI => {
198+
// /// OpenAI API [error object](https://github.com/openai/openai-go/blob/main/internal/apierror/apierror.go#L17).
199+
// #[derive(Clone, serde::Deserialize)]
200+
// pub struct OpenAIError {
201+
// code: String,
202+
// message: String,
203+
// }
204204

205-
serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
206-
|ErrorObject {
207-
error: openai_error,
208-
}| TaskError::ProviderError {
209-
code: openai_error.code,
210-
message: openai_error.message,
211-
provider: provider.to_string(),
212-
},
213-
)
214-
}
215-
ModelProvider::OpenRouter => {
216-
/// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
217-
#[derive(Clone, serde::Deserialize)]
218-
pub struct OpenRouterError {
219-
code: u32,
220-
message: String,
221-
}
205+
// serde_json::from_str::<ErrorObject<OpenAIError>>(err_inner).map(
206+
// |ErrorObject {
207+
// error: openai_error,
208+
// }| TaskError::ProviderError {
209+
// code: openai_error.code,
210+
// message: openai_error.message,
211+
// provider: provider.to_string(),
212+
// },
213+
// )
214+
// }
215+
// ModelProvider::OpenRouter => {
216+
// /// OpenRouter API [error object](https://openrouter.ai/docs/api-reference/errors).
217+
// #[derive(Clone, serde::Deserialize)]
218+
// pub struct OpenRouterError {
219+
// code: u32,
220+
// message: String,
221+
// }
222222

223-
serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
224-
|ErrorObject {
225-
error: openrouter_error,
226-
}| {
227-
TaskError::ProviderError {
228-
code: openrouter_error.code.to_string(),
229-
message: openrouter_error.message,
230-
provider: provider.to_string(),
231-
}
232-
},
233-
)
234-
}
223+
// serde_json::from_str::<ErrorObject<OpenRouterError>>(err_inner).map(
224+
// |ErrorObject {
225+
// error: openrouter_error,
226+
// }| {
227+
// TaskError::ProviderError {
228+
// code: openrouter_error.code.to_string(),
229+
// message: openrouter_error.message,
230+
// provider: provider.to_string(),
231+
// }
232+
// },
233+
// )
234+
// }
235235
ModelProvider::Ollama => serde_json::from_str::<ErrorObject<String>>(err_inner)
236236
.map(
237237
// Ollama just returns a string error message

compute/src/utils/specs.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ mod tests {
9191
vec![Model::Gemma3_4b.to_string()],
9292
HashMap::from_iter([
9393
(Model::Gemma3_4b, SpecModelPerformance::PassedWithTPS(100.0)),
94-
(Model::GPT4oMini, SpecModelPerformance::NotFound),
9594
(Model::Gemma3_27b, SpecModelPerformance::ExecutionFailed),
9695
]),
9796
SemanticVersion {

compute/src/workers/task.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ impl TaskWorker {
139139
);
140140
debug_assert!(num_tasks != 0, "number of tasks cant be zero");
141141

142-
log::info!("Processing {} tasks in batch", num_tasks);
142+
log::info!("Processing {num_tasks} tasks in batch");
143143
let mut batch = tasks.into_iter().map(|b| (b, &self.publish_tx));
144144
match num_tasks {
145145
1 => {
@@ -235,8 +235,8 @@ impl TaskWorker {
235235
stats: input.stats,
236236
};
237237

238-
if let Err(e) = publish_tx.send(output).await {
239-
log::error!("Error sending task result: {}", e);
238+
if let Err(err) = publish_tx.send(output).await {
239+
log::error!("Error sending task result: {err}");
240240
}
241241
}
242242
}
@@ -254,7 +254,7 @@ mod tests {
254254
/// cargo test --package dkn-compute --lib --all-features -- workers::task::tests::test_executor_worker --exact --show-output --nocapture --ignored
255255
/// ```
256256
#[tokio::test]
257-
#[ignore = "run manually"]
257+
#[ignore = "run manually with Ollama"]
258258
async fn test_executor_worker() {
259259
let _ = env_logger::builder()
260260
.filter_level(log::LevelFilter::Off)
@@ -271,7 +271,7 @@ mod tests {
271271
});
272272

273273
let num_tasks = 4;
274-
let model = Model::GPT4o;
274+
let model = Model::Llama3_2_1bInstructQ4Km;
275275
let executor = DriaExecutor::new_from_env(model.provider()).unwrap();
276276
let task = TaskBody::new_prompt("Write a poem about Julius Caesar.", model.clone());
277277

0 commit comments

Comments
 (0)