diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index aaa412abd..8e8190f53 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -1,5 +1,7 @@
 use std::borrow::Cow;
 
+use anyhow::Result;
+use async_trait::async_trait;
 use schemars::schema::SchemaObject;
 use serde::{Deserialize, Serialize};
 
@@ -32,5 +34,21 @@ pub struct LlmGenerateResponse {
     pub text: String,
 }
 
-mod client;
-pub use client::LlmClient;
+#[async_trait]
+pub trait LlmGenerationClient: Send + Sync {
+    async fn generate<'req>(
+        &self,
+        request: LlmGenerateRequest<'req>,
+    ) -> Result<LlmGenerateResponse>;
+}
+
+mod ollama;
+
+pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
+    let client = match spec.api_type {
+        LlmApiType::Ollama => {
+            Box::new(ollama::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
+    };
+    Ok(client)
+}
diff --git a/src/llm/client.rs b/src/llm/ollama.rs
similarity index 89%
rename from src/llm/client.rs
rename to src/llm/ollama.rs
index 28ada294f..31bb22447 100644
--- a/src/llm/client.rs
+++ b/src/llm/ollama.rs
@@ -1,8 +1,10 @@
+use super::LlmGenerationClient;
 use anyhow::Result;
+use async_trait::async_trait;
 use schemars::schema::SchemaObject;
 use serde::{Deserialize, Serialize};
 
-pub struct LlmClient {
+pub struct Client {
     generate_url: String,
     model: String,
     reqwest_client: reqwest::Client,
@@ -30,7 +32,7 @@ struct OllamaResponse {
 
 const OLLAMA_DEFAULT_ADDRESS: &str = "http://localhost:11434";
 
-impl LlmClient {
+impl Client {
     pub async fn new(spec: super::LlmSpec) -> Result<Self> {
         let address = match &spec.address {
             Some(addr) => addr.trim_end_matches('/'),
@@ -42,10 +44,13 @@ impl LlmClient {
             reqwest_client: reqwest::Client::new(),
         })
     }
+}
 
-    pub async fn generate<'a>(
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
         &self,
-        request: super::LlmGenerateRequest<'a>,
+        request: super::LlmGenerateRequest<'req>,
     ) -> Result<super::LlmGenerateResponse> {
         let req = OllamaRequest {
             model: &self.model,
diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs
index 3b3989dbc..64363ebc1 100644
--- a/src/ops/functions/extract_by_llm.rs
+++ b/src/ops/functions/extract_by_llm.rs
@@ -5,7 +5,9 @@ use schemars::schema::SchemaObject;
 use serde::Serialize;
 
 use crate::base::json_schema::ToJsonSchema;
-use crate::llm::{LlmClient, LlmGenerateRequest, LlmSpec, OutputFormat};
+use crate::llm::{
+    new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
+};
 use crate::ops::sdk::*;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -16,7 +18,7 @@ pub struct Spec {
 }
 
 struct Executor {
-    client: LlmClient,
+    client: Box<dyn LlmGenerationClient>,
     output_json_schema: SchemaObject,
     output_type: EnrichedValueType,
     system_prompt: String,
@@ -41,7 +43,7 @@ Output only the JSON without any additional messages or explanations."
 impl Executor {
     async fn new(spec: Spec) -> Result<Self> {
         Ok(Self {
-            client: LlmClient::new(spec.llm_spec).await?,
+            client: new_llm_generation_client(spec.llm_spec).await?,
             output_json_schema: spec.output_type.to_json_schema(),
             output_type: spec.output_type,
             system_prompt: get_system_prompt(&spec.instruction),
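
The point of the trait-plus-factory split is that callers like `extract_by_llm` only hold a `Box<dyn LlmGenerationClient>`, so adding another API later means one new module and one more `match` arm in `new_llm_generation_client`. As a minimal sketch of that extension path (not part of this diff): the `openai` module name, the `LlmApiType::OpenAi` variant, and the `model` field on `LlmSpec` are illustrative assumptions here, and the request body is left as a `todo!()` rather than guessing at any backend's wire format.

```rust
// Hypothetical second backend, e.g. src/llm/openai.rs (illustrative only).
use anyhow::Result;
use async_trait::async_trait;

use super::{LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec};

pub struct Client {
    model: String,
}

impl Client {
    pub async fn new(spec: LlmSpec) -> Result<Self> {
        // Pull whatever connection settings this backend needs out of `spec`;
        // `spec.model` mirrors how the Ollama client is assumed to use it.
        Ok(Self { model: spec.model })
    }
}

#[async_trait]
impl LlmGenerationClient for Client {
    async fn generate<'req>(
        &self,
        _request: LlmGenerateRequest<'req>,
    ) -> Result<LlmGenerateResponse> {
        // Issue the backend-specific HTTP request here and map the reply
        // into LlmGenerateResponse { text }.
        todo!("call the backend API for model {}", self.model)
    }
}
```

Wiring it up would then be a single extra arm in the factory, along the lines of `LlmApiType::OpenAi => Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>`, with no changes needed in `extract_by_llm.rs`.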