Skip to content

Commit d1c5068

Browse files
committed
Refactor the Ollama client logic into a trait, to make the next backend integration easier
1 parent 6b44cee commit d1c5068

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

src/llm/mod.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::borrow::Cow;
22

3+
use anyhow::Result;
4+
use async_trait::async_trait;
35
use schemars::schema::SchemaObject;
46
use serde::{Deserialize, Serialize};
57

@@ -32,5 +34,21 @@ pub struct LlmGenerateResponse {
3234
pub text: String,
3335
}
3436

35-
mod client;
36-
pub use client::LlmClient;
37+
#[async_trait]
38+
pub trait LlmGenerationClient: Send + Sync {
39+
async fn generate<'req>(
40+
&self,
41+
request: LlmGenerateRequest<'req>,
42+
) -> Result<LlmGenerateResponse>;
43+
}
44+
45+
mod ollama;
46+
47+
pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
48+
let client = match spec.api_type {
49+
LlmApiType::Ollama => {
50+
Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
51+
}
52+
};
53+
Ok(client)
54+
}

src/llm/client.rs renamed to src/llm/ollama.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
use super::LlmGenerationClient;
12
use anyhow::Result;
3+
use async_trait::async_trait;
24
use schemars::schema::SchemaObject;
35
use serde::{Deserialize, Serialize};
46

5-
pub struct LlmClient {
7+
pub struct Client {
68
generate_url: String,
79
model: String,
810
reqwest_client: reqwest::Client,
@@ -30,7 +32,7 @@ struct OllamaResponse {
3032

3133
const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
3234

33-
impl LlmClient {
35+
impl Client {
3436
pub async fn new(spec: super::LlmSpec) -> Result<Self> {
3537
let address = match &spec.address {
3638
Some(addr) => addr.trim_end_matches('/'),
@@ -42,10 +44,13 @@ impl LlmClient {
4244
reqwest_client: reqwest::Client::new(),
4345
})
4446
}
47+
}
4548

46-
pub async fn generate<'a>(
49+
#[async_trait]
50+
impl LlmGenerationClient for Client {
51+
async fn generate<'req>(
4752
&self,
48-
request: super::LlmGenerateRequest<'a>,
53+
request: super::LlmGenerateRequest<'req>,
4954
) -> Result<super::LlmGenerateResponse> {
5055
let req = OllamaRequest {
5156
model: &self.model,

src/ops/functions/extract_by_llm.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use schemars::schema::SchemaObject;
55
use serde::Serialize;
66

77
use crate::base::json_schema::ToJsonSchema;
8-
use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat};
8+
use crate::llm::{
9+
new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
10+
};
911
use crate::ops::sdk::*;
1012

1113
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -16,7 +18,7 @@ pub struct Spec {
1618
}
1719

1820
struct Executor {
19-
client: LlmClient,
21+
client: Box<dyn LlmGenerationClient>,
2022
output_json_schema: SchemaObject,
2123
output_type: EnrichedValueType,
2224
system_prompt: String,
@@ -41,7 +43,7 @@ Output only the JSON without any additional messages or explanations."
4143
impl Executor {
4244
async fn new(spec: Spec) -> Result<Self> {
4345
Ok(Self {
44-
client: LlmClient::new(spec.llm_spec).await?,
46+
client: new_llm_generation_client(spec.llm_spec).await?,
4547
output_json_schema: spec.output_type.to_json_schema(),
4648
output_type: spec.output_type,
4749
system_prompt: get_system_prompt(&spec.instruction),

0 commit comments

Comments
 (0)