From 6720b369bc2d7d8bcf8446cef2adb6732b47c211 Mon Sep 17 00:00:00 2001 From: Param Arora Date: Wed, 28 May 2025 12:13:27 +0530 Subject: [PATCH 1/2] feat: llm generate json response --- src/llm/anthropic.rs | 76 +++++++---------------------- src/llm/gemini.rs | 42 +++++++++------- src/llm/mod.rs | 8 +-- src/llm/ollama.rs | 19 ++++++-- src/llm/openai.rs | 16 ++++-- src/llm/prompt_utils.rs | 4 ++ src/ops/functions/extract_by_llm.rs | 9 ++-- 7 files changed, 85 insertions(+), 89 deletions(-) create mode 100644 src/llm/prompt_utils.rs diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs index 1001f9083..3d770e5f4 100644 --- a/src/llm/anthropic.rs +++ b/src/llm/anthropic.rs @@ -1,10 +1,7 @@ -use crate::llm::{ - LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat, - ToJsonSchemaOptions, -}; -use anyhow::{bail, Context, Result}; use async_trait::async_trait; -use json5; +use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat}; +use anyhow::{Result, bail, Context}; +use crate::llm::prompt_utils::STRICT_JSON_PROMPT; use serde_json::Value; use crate::api_bail; @@ -48,9 +45,11 @@ impl LlmGenerationClient for Client { }); // Add system prompt as top-level field if present (required) - if let Some(system) = request.system_prompt { - payload["system"] = serde_json::json!(system); + let mut system_prompt = request.system_prompt.unwrap_or_default(); + if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) { + system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into(); } + payload["system"] = serde_json::json!(system_prompt); // Extract schema from output_format, error if not JsonSchema let schema = match request.output_format.as_ref() { @@ -67,8 +66,7 @@ impl LlmGenerationClient for Client { let encoded_api_key = encode(&self.api_key); - let resp = self - .client + let resp = self.client .post(url) .header("x-api-key", encoded_api_key.as_ref()) .header("anthropic-version", "2023-06-01") @@ -76,60 +74,22 @@ impl LlmGenerationClient for Client { .send() .await .context("HTTP error")?; - let mut resp_json: Value = resp.json().await.context("Invalid JSON")?; + let resp_json: Value = resp.json().await.context("Invalid JSON")?; if let Some(error) = resp_json.get("error") { bail!("Anthropic API error: {:?}", error); } - // Debug print full response - // println!("Anthropic API full response: {resp_json:?}"); - - let resp_content = &resp_json["content"]; - let tool_name = "report_result"; - let mut extracted_json: Option = None; - if let Some(array) = resp_content.as_array() { - for item in array { - if item.get("type") == Some(&Value::String("tool_use".to_string())) - && item.get("name") == Some(&Value::String(tool_name.to_string())) - { - if let Some(input) = item.get("input") { - extracted_json = Some(input.clone()); - break; - } - } - } - } - let text = if let Some(json) = extracted_json { - // Try strict JSON serialization first - serde_json::to_string(&json)? - } else { - // Fallback: try text if no tool output found - match &mut resp_json["content"][0]["text"] { - Value::String(s) => { - // Try strict JSON parsing first - match serde_json::from_str::(s) { - Ok(_) => std::mem::take(s), - Err(e) => { - // Try permissive json5 parsing as fallback - match json5::from_str::(s) { - Ok(value) => { - println!("[Anthropic] Used permissive JSON5 parser for output"); - serde_json::to_string(&value)? - }, - Err(e2) => return Err(anyhow::anyhow!(format!("No structured tool output or text found in response, and permissive JSON5 parsing also failed: {e}; {e2}"))) - } - } - } - } - _ => { - return Err(anyhow::anyhow!( - "No structured tool output or text found in response" - )) - } - } + // Extract the text response + let text = match resp_json["content"][0]["text"].as_str() { + Some(s) => s.to_string(), + None => bail!("No text in response"), }; - Ok(LlmGenerateResponse { text }) + // Try to parse as JSON + match serde_json::from_str::(&text) { + Ok(val) => Ok(LlmGenerateResponse::Json(val)), + Err(_) => Ok(LlmGenerateResponse::Text(text)), + } } fn json_schema_options(&self) -> ToJsonSchemaOptions { diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index 11c34ebb1..4a7600cb6 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -1,12 +1,10 @@ -use crate::api_bail; -use crate::llm::{ - LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat, - ToJsonSchemaOptions, -}; -use anyhow::{bail, Context, Result}; use async_trait::async_trait; +use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat}; +use anyhow::{Result, bail, Context}; use serde_json::Value; +use crate::api_bail; use urlencoding::encode; +use crate::llm::prompt_utils::STRICT_JSON_PROMPT; pub struct Client { model: String, @@ -60,11 +58,14 @@ impl LlmGenerationClient for Client { // Prepare payload let mut payload = serde_json::json!({ "contents": contents }); - if let Some(system) = request.system_prompt { - payload["systemInstruction"] = serde_json::json!({ - "parts": [ { "text": system } ] - }); - } + if let Some(mut system) = request.system_prompt { + if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) { + system = format!("{STRICT_JSON_PROMPT}\n\n{system}").into(); + } + payload["systemInstruction"] = serde_json::json!({ + "parts": [ { "text": system } ] + }); +} // If structured output is requested, add schema and responseMimeType if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format { @@ -79,13 +80,10 @@ impl LlmGenerationClient for Client { let api_key = &self.api_key; let url = format!( "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", - encode(&self.model), - encode(api_key) + encode(&self.model), encode(api_key) ); - let resp = self - .client - .post(&url) + let resp = self.client.post(&url) .json(&payload) .send() .await @@ -102,7 +100,15 @@ impl LlmGenerationClient for Client { _ => bail!("No text in response"), }; - Ok(LlmGenerateResponse { text }) + // If output_format is JsonSchema, try to parse as JSON + if let Some(OutputFormat::JsonSchema { .. }) = request.output_format { + match serde_json::from_str::(&text) { + Ok(val) => Ok(LlmGenerateResponse::Json(val)), + Err(_) => Ok(LlmGenerateResponse::Text(text)), + } + } else { + Ok(LlmGenerateResponse::Text(text)) + } } fn json_schema_options(&self) -> ToJsonSchemaOptions { @@ -113,4 +119,4 @@ impl LlmGenerationClient for Client { top_level_must_be_object: true, } } -} +} \ No newline at end of file diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 5a2706aa3..b8865e8d5 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -22,7 +22,7 @@ pub struct LlmSpec { model: String, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum OutputFormat<'a> { JsonSchema { name: Cow<'a, str>, @@ -38,8 +38,9 @@ pub struct LlmGenerateRequest<'a> { } #[derive(Debug)] -pub struct LlmGenerateResponse { - pub text: String, +pub enum LlmGenerateResponse { + Json(serde_json::Value), + Text(String), } #[async_trait] @@ -56,6 +57,7 @@ mod anthropic; mod gemini; mod ollama; mod openai; +mod prompt_utils; pub async fn new_llm_generation_client(spec: LlmSpec) -> Result> { let client = match spec.api_type { diff --git a/src/llm/ollama.rs b/src/llm/ollama.rs index f29260775..afddaecfb 100644 --- a/src/llm/ollama.rs +++ b/src/llm/ollama.rs @@ -1,6 +1,7 @@ use super::LlmGenerationClient; use anyhow::Result; use async_trait::async_trait; +use crate::llm::prompt_utils::STRICT_JSON_PROMPT; use schemars::schema::SchemaObject; use serde::{Deserialize, Serialize}; @@ -52,6 +53,10 @@ impl LlmGenerationClient for Client { &self, request: super::LlmGenerateRequest<'req>, ) -> Result { + let mut system_prompt = request.system_prompt.unwrap_or_default(); + if matches!(request.output_format, Some(super::OutputFormat::JsonSchema { .. })) { + system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into(); + } let req = OllamaRequest { model: &self.model, prompt: request.user_prompt.as_ref(), @@ -60,7 +65,7 @@ impl LlmGenerationClient for Client { OllamaFormat::JsonSchema(schema.as_ref()) }, ), - system: request.system_prompt.as_ref().map(|s| s.as_ref()), + system: Some(&system_prompt), stream: Some(false), }; let res = self @@ -71,9 +76,15 @@ impl LlmGenerationClient for Client { .await?; let body = res.text().await?; let json: OllamaResponse = serde_json::from_str(&body)?; - Ok(super::LlmGenerateResponse { - text: json.response, - }) + // Check if output_format is JsonSchema, try to parse as JSON + if let Some(super::OutputFormat::JsonSchema { .. }) = request.output_format { + match serde_json::from_str::(&json.response) { + Ok(val) => Ok(super::LlmGenerateResponse::Json(val)), + Err(_) => Ok(super::LlmGenerateResponse::Text(json.response)), + } + } else { + Ok(super::LlmGenerateResponse::Text(json.response)) + } } fn json_schema_options(&self) -> super::ToJsonSchemaOptions { diff --git a/src/llm/openai.rs b/src/llm/openai.rs index 5675fc86e..f75e1b669 100644 --- a/src/llm/openai.rs +++ b/src/llm/openai.rs @@ -64,8 +64,10 @@ impl LlmGenerationClient for Client { }, )); + // Save output_format before it is moved. + let output_format = request.output_format.clone(); // Create the chat completion request - let request = CreateChatCompletionRequest { + let openai_request = CreateChatCompletionRequest { model: self.model.clone(), messages, response_format: match request.output_format { @@ -85,7 +87,7 @@ impl LlmGenerationClient for Client { }; // Send request and get response - let response = self.client.chat().create(request).await?; + let response = self.client.chat().create(openai_request).await?; // Extract the response text from the first choice let text = response @@ -95,7 +97,15 @@ impl LlmGenerationClient for Client { .and_then(|choice| choice.message.content) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; - Ok(super::LlmGenerateResponse { text }) + // If output_format is JsonSchema, try to parse as JSON + if let Some(super::OutputFormat::JsonSchema { .. }) = output_format { + match serde_json::from_str::(&text) { + Ok(val) => Ok(super::LlmGenerateResponse::Json(val)), + Err(_) => Ok(super::LlmGenerateResponse::Text(text)), + } + } else { + Ok(super::LlmGenerateResponse::Text(text)) + } } fn json_schema_options(&self) -> super::ToJsonSchemaOptions { diff --git a/src/llm/prompt_utils.rs b/src/llm/prompt_utils.rs new file mode 100644 index 000000000..fe28f8930 --- /dev/null +++ b/src/llm/prompt_utils.rs @@ -0,0 +1,4 @@ +// Shared prompt utilities for LLM clients +// Only import this in clients that require strict JSON output instructions (e.g., Anthropic, Gemini, Ollama) + +pub const STRICT_JSON_PROMPT: &str = "IMPORTANT: Output ONLY valid JSON that matches the schema. Do NOT say anything else. Do NOT explain. Do NOT preface. Do NOT add comments. If you cannot answer, output an empty JSON object: {}."; diff --git a/src/ops/functions/extract_by_llm.rs b/src/ops/functions/extract_by_llm.rs index 060956c70..7a3abf237 100644 --- a/src/ops/functions/extract_by_llm.rs +++ b/src/ops/functions/extract_by_llm.rs @@ -1,7 +1,7 @@ use crate::prelude::*; use crate::llm::{ - new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat, + new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmGenerateResponse, LlmSpec, OutputFormat, }; use crate::ops::sdk::*; use base::json_schema::build_json_schema; @@ -83,7 +83,10 @@ impl SimpleFunctionExecutor for Executor { }), }; let res = self.client.generate(req).await?; - let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?; + let json_value = match res { + LlmGenerateResponse::Json(val) => val, + LlmGenerateResponse::Text(text) => serde_json::from_str(&text)?, + }; let value = self.value_extractor.extract_value(json_value)?; Ok(value) } @@ -124,4 +127,4 @@ impl SimpleFunctionFactoryBase for Factory { ) -> Result> { Ok(Box::new(Executor::new(spec, resolved_input_schema).await?)) } -} +} \ No newline at end of file From 552a0b54878f2c2c8ff02b15c5876b210ee88cbb Mon Sep 17 00:00:00 2001 From: Param Arora Date: Wed, 28 May 2025 12:15:17 +0530 Subject: [PATCH 2/2] remove json5 fallback --- Cargo.lock | 1 - Cargo.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 848cbd6dc..c39c6efc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1023,7 +1023,6 @@ dependencies = [ "indexmap 2.9.0", "indoc", "itertools 0.14.0", - "json5", "log", "neo4rs", "owo-colors", diff --git a/Cargo.toml b/Cargo.toml index 1a19b22ba..24ffcb3fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,7 +108,6 @@ bytes = "1.10.1" rand = "0.9.0" indoc = "2.0.6" owo-colors = "4.2.0" -json5 = "0.4.1" aws-config = "1.6.2" aws-sdk-s3 = "1.85.0" aws-sdk-sqs = "1.67.0"