|
| 1 | +use async_trait::async_trait; |
| 2 | +use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat}; |
| 3 | +use anyhow::{Result, bail, Context}; |
| 4 | +use serde_json::Value; |
| 5 | +use json5; |
| 6 | + |
| 7 | +use crate::api_bail; |
| 8 | +use urlencoding::encode; |
| 9 | + |
| 10 | +pub struct Client { |
| 11 | + model: String, |
| 12 | + api_key: String, |
| 13 | + client: reqwest::Client, |
| 14 | +} |
| 15 | + |
| 16 | +impl Client { |
| 17 | + pub async fn new(spec: LlmSpec) -> Result<Self> { |
| 18 | + let api_key = match std::env::var("ANTHROPIC_API_KEY") { |
| 19 | + Ok(val) => val, |
| 20 | + Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"), |
| 21 | + }; |
| 22 | + Ok(Self { |
| 23 | + model: spec.model, |
| 24 | + api_key, |
| 25 | + client: reqwest::Client::new(), |
| 26 | + }) |
| 27 | + } |
| 28 | +} |
| 29 | + |
| 30 | +#[async_trait] |
| 31 | +impl LlmGenerationClient for Client { |
| 32 | + async fn generate<'req>( |
| 33 | + &self, |
| 34 | + request: LlmGenerateRequest<'req>, |
| 35 | + ) -> Result<LlmGenerateResponse> { |
| 36 | + let messages = vec![serde_json::json!({ |
| 37 | + "role": "user", |
| 38 | + "content": request.user_prompt |
| 39 | + })]; |
| 40 | + |
| 41 | + let mut payload = serde_json::json!({ |
| 42 | + "model": self.model, |
| 43 | + "messages": messages, |
| 44 | + "max_tokens": 4096 |
| 45 | + }); |
| 46 | + |
| 47 | + // Add system prompt as top-level field if present (required) |
| 48 | + if let Some(system) = request.system_prompt { |
| 49 | + payload["system"] = serde_json::json!(system); |
| 50 | + } |
| 51 | + |
| 52 | + // Extract schema from output_format, error if not JsonSchema |
| 53 | + let schema = match request.output_format.as_ref() { |
| 54 | + Some(OutputFormat::JsonSchema { schema, .. }) => schema, |
| 55 | + _ => api_bail!("Anthropic client expects OutputFormat::JsonSchema for all requests"), |
| 56 | + }; |
| 57 | + |
| 58 | + let schema_json = serde_json::to_value(schema)?; |
| 59 | + payload["tools"] = serde_json::json!([ |
| 60 | + { "type": "custom", "name": "report_result", "input_schema": schema_json } |
| 61 | + ]); |
| 62 | + |
| 63 | + let url = "https://api.anthropic.com/v1/messages"; |
| 64 | + |
| 65 | + let encoded_api_key = encode(&self.api_key); |
| 66 | + |
| 67 | + let resp = self.client |
| 68 | + .post(url) |
| 69 | + .header("x-api-key", encoded_api_key.as_ref()) |
| 70 | + .header("anthropic-version", "2023-06-01") |
| 71 | + .json(&payload) |
| 72 | + .send() |
| 73 | + .await |
| 74 | + .context("HTTP error")?; |
| 75 | + let mut resp_json: Value = resp.json().await.context("Invalid JSON")?; |
| 76 | + if let Some(error) = resp_json.get("error") { |
| 77 | + bail!("Anthropic API error: {:?}", error); |
| 78 | + } |
| 79 | + |
| 80 | + // Debug print full response |
| 81 | + // println!("Anthropic API full response: {resp_json:?}"); |
| 82 | + |
| 83 | + let resp_content = &resp_json["content"]; |
| 84 | + let tool_name = "report_result"; |
| 85 | + let mut extracted_json: Option<Value> = None; |
| 86 | + if let Some(array) = resp_content.as_array() { |
| 87 | + for item in array { |
| 88 | + if item.get("type") == Some(&Value::String("tool_use".to_string())) |
| 89 | + && item.get("name") == Some(&Value::String(tool_name.to_string())) |
| 90 | + { |
| 91 | + if let Some(input) = item.get("input") { |
| 92 | + extracted_json = Some(input.clone()); |
| 93 | + break; |
| 94 | + } |
| 95 | + } |
| 96 | + } |
| 97 | + } |
| 98 | + let text = if let Some(json) = extracted_json { |
| 99 | + // Try strict JSON serialization first |
| 100 | + serde_json::to_string(&json)? |
| 101 | + } else { |
| 102 | + // Fallback: try text if no tool output found |
| 103 | + match &mut resp_json["content"][0]["text"] { |
| 104 | + Value::String(s) => { |
| 105 | + // Try strict JSON parsing first |
| 106 | + match serde_json::from_str::<serde_json::Value>(s) { |
| 107 | + Ok(_) => std::mem::take(s), |
| 108 | + Err(e) => { |
| 109 | + // Try permissive json5 parsing as fallback |
| 110 | + match json5::from_str::<serde_json::Value>(s) { |
| 111 | + Ok(value) => { |
| 112 | + println!("[Anthropic] Used permissive JSON5 parser for output"); |
| 113 | + serde_json::to_string(&value)? |
| 114 | + }, |
| 115 | + Err(e2) => return Err(anyhow::anyhow!(format!("No structured tool output or text found in response, and permissive JSON5 parsing also failed: {e}; {e2}"))) |
| 116 | + } |
| 117 | + } |
| 118 | + } |
| 119 | + }, |
| 120 | + _ => return Err(anyhow::anyhow!("No structured tool output or text found in response")), |
| 121 | + } |
| 122 | + }; |
| 123 | + |
| 124 | + Ok(LlmGenerateResponse { |
| 125 | + text, |
| 126 | + }) |
| 127 | + } |
| 128 | + |
| 129 | + fn json_schema_options(&self) -> ToJsonSchemaOptions { |
| 130 | + ToJsonSchemaOptions { |
| 131 | + fields_always_required: false, |
| 132 | + supports_format: false, |
| 133 | + extract_descriptions: false, |
| 134 | + top_level_must_be_object: true, |
| 135 | + } |
| 136 | + } |
| 137 | +} |
0 commit comments