|
| 1 | +use async_trait::async_trait; |
| 2 | +use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat}; |
| 3 | +use anyhow::{Result, bail, Context}; |
| 4 | +use serde_json::Value; |
| 5 | +use crate::api_bail; |
| 6 | +use urlencoding::encode; |
| 7 | + |
| 8 | +pub struct Client { |
| 9 | + model: String, |
| 10 | + api_key: String, |
| 11 | + client: reqwest::Client, |
| 12 | +} |
| 13 | + |
| 14 | +impl Client { |
| 15 | + pub async fn new(spec: LlmSpec) -> Result<Self> { |
| 16 | + let api_key = match std::env::var("ANTHROPIC_API_KEY") { |
| 17 | + Ok(val) => val, |
| 18 | + Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"), |
| 19 | + }; |
| 20 | + Ok(Self { |
| 21 | + model: spec.model, |
| 22 | + api_key, |
| 23 | + client: reqwest::Client::new(), |
| 24 | + }) |
| 25 | + } |
| 26 | +} |
| 27 | + |
| 28 | +#[async_trait] |
| 29 | +impl LlmGenerationClient for Client { |
| 30 | + async fn generate<'req>( |
| 31 | + &self, |
| 32 | + request: LlmGenerateRequest<'req>, |
| 33 | + ) -> Result<LlmGenerateResponse> { |
| 34 | + // Compose the prompt/messages |
| 35 | + let mut messages = vec![serde_json::json!({ |
| 36 | + "role": "user", |
| 37 | + "content": request.user_prompt |
| 38 | + })]; |
| 39 | + if let Some(system) = request.system_prompt { |
| 40 | + messages.insert(0, serde_json::json!({ |
| 41 | + "role": "system", |
| 42 | + "content": system |
| 43 | + })); |
| 44 | + } |
| 45 | + |
| 46 | + let mut payload = serde_json::json!({ |
| 47 | + "model": self.model, |
| 48 | + "messages": messages, |
| 49 | + "max_tokens": 4096 |
| 50 | + }); |
| 51 | + |
| 52 | + // If structured output is requested, add schema |
| 53 | + if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format { |
| 54 | + let schema_json = serde_json::to_value(schema)?; |
| 55 | + payload["tools"] = serde_json::json!([ |
| 56 | + { "type": "json_object", "parameters": schema_json } |
| 57 | + ]); |
| 58 | + } |
| 59 | + |
| 60 | + let url = "https://api.anthropic.com/v1/messages"; |
| 61 | + |
| 62 | + let encoded_api_key = encode(&self.api_key); |
| 63 | + let resp = self.client |
| 64 | + .post(url) |
| 65 | + .header("x-api-key", encoded_api_key.as_ref()) |
| 66 | + .json(&payload) |
| 67 | + .send() |
| 68 | + .await |
| 69 | + .context("HTTP error")?; |
| 70 | + |
| 71 | + let resp_json: Value = resp.json().await.context("Invalid JSON")?; |
| 72 | + |
| 73 | + if let Some(error) = resp_json.get("error") { |
| 74 | + bail!("Anthropic API error: {:?}", error); |
| 75 | + } |
| 76 | + let mut resp_json = resp_json; |
| 77 | + let text = match &mut resp_json["content"][0]["text"] { |
| 78 | + Value::String(s) => std::mem::take(s), |
| 79 | + _ => bail!("No text in response"), |
| 80 | + }; |
| 81 | + |
| 82 | + Ok(LlmGenerateResponse { |
| 83 | + text, |
| 84 | + }) |
| 85 | + } |
| 86 | + |
| 87 | + fn json_schema_options(&self) -> ToJsonSchemaOptions { |
| 88 | + ToJsonSchemaOptions { |
| 89 | + fields_always_required: false, |
| 90 | + supports_format: false, |
| 91 | + extract_descriptions: false, |
| 92 | + top_level_must_be_object: true, |
| 93 | + } |
| 94 | + } |
| 95 | +} |
0 commit comments