
Commit 6720b36

feat: llm generate json response
1 parent ede2b4a commit 6720b36

File tree: 7 files changed, +85 -89 lines changed


src/llm/anthropic.rs

Lines changed: 18 additions & 58 deletions
@@ -1,10 +1,7 @@
-use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
-};
-use anyhow::{bail, Context, Result};
 use async_trait::async_trait;
-use json5;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
+use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
 use serde_json::Value;
 
 use crate::api_bail;
@@ -48,9 +45,11 @@ impl LlmGenerationClient for Client {
         });
 
         // Add system prompt as top-level field if present (required)
-        if let Some(system) = request.system_prompt {
-            payload["system"] = serde_json::json!(system);
+        let mut system_prompt = request.system_prompt.unwrap_or_default();
+        if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
+            system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
         }
+        payload["system"] = serde_json::json!(system_prompt);
 
         // Extract schema from output_format, error if not JsonSchema
         let schema = match request.output_format.as_ref() {
@@ -67,69 +66,30 @@ impl LlmGenerationClient for Client {
 
         let encoded_api_key = encode(&self.api_key);
 
-        let resp = self
-            .client
+        let resp = self.client
             .post(url)
             .header("x-api-key", encoded_api_key.as_ref())
             .header("anthropic-version", "2023-06-01")
             .json(&payload)
             .send()
             .await
            .context("HTTP error")?;
-        let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;
+        let resp_json: Value = resp.json().await.context("Invalid JSON")?;
         if let Some(error) = resp_json.get("error") {
             bail!("Anthropic API error: {:?}", error);
         }
 
-        // Debug print full response
-        // println!("Anthropic API full response: {resp_json:?}");
-
-        let resp_content = &resp_json["content"];
-        let tool_name = "report_result";
-        let mut extracted_json: Option<Value> = None;
-        if let Some(array) = resp_content.as_array() {
-            for item in array {
-                if item.get("type") == Some(&Value::String("tool_use".to_string()))
-                    && item.get("name") == Some(&Value::String(tool_name.to_string()))
-                {
-                    if let Some(input) = item.get("input") {
-                        extracted_json = Some(input.clone());
-                        break;
-                    }
-                }
-            }
-        }
-        let text = if let Some(json) = extracted_json {
-            // Try strict JSON serialization first
-            serde_json::to_string(&json)?
-        } else {
-            // Fallback: try text if no tool output found
-            match &mut resp_json["content"][0]["text"] {
-                Value::String(s) => {
-                    // Try strict JSON parsing first
-                    match serde_json::from_str::<serde_json::Value>(s) {
-                        Ok(_) => std::mem::take(s),
-                        Err(e) => {
-                            // Try permissive json5 parsing as fallback
-                            match json5::from_str::<serde_json::Value>(s) {
-                                Ok(value) => {
-                                    println!("[Anthropic] Used permissive JSON5 parser for output");
-                                    serde_json::to_string(&value)?
-                                },
-                                Err(e2) => return Err(anyhow::anyhow!(format!("No structured tool output or text found in response, and permissive JSON5 parsing also failed: {e}; {e2}")))
-                            }
-                        }
-                    }
-                }
-                _ => {
-                    return Err(anyhow::anyhow!(
-                        "No structured tool output or text found in response"
-                    ))
-                }
-            }
+        // Extract the text response
+        let text = match resp_json["content"][0]["text"].as_str() {
+            Some(s) => s.to_string(),
+            None => bail!("No text in response"),
         };
 
-        Ok(LlmGenerateResponse { text })
+        // Try to parse as JSON
+        match serde_json::from_str::<serde_json::Value>(&text) {
+            Ok(val) => Ok(LlmGenerateResponse::Json(val)),
+            Err(_) => Ok(LlmGenerateResponse::Text(text)),
+        }
     }
 
     fn json_schema_options(&self) -> ToJsonSchemaOptions {
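
With the tool_use extraction and the permissive json5 fallback removed, the client depends on the Messages API returning its answer as plain text in content[0].text. For reference, a minimal sketch of the response shape that code indexes into (illustrative values, not from this commit):

// Illustrative Anthropic Messages API response body; the values are
// invented, but `content[0].text` is the field the new code reads before
// attempting a strict serde_json parse.
let resp_json = serde_json::json!({
    "content": [
        { "type": "text", "text": "{\"name\": \"Ada\", \"age\": 36}" }
    ],
    "stop_reason": "end_turn"
});
assert_eq!(
    resp_json["content"][0]["text"].as_str(),
    Some("{\"name\": \"Ada\", \"age\": 36}")
);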

src/llm/gemini.rs

Lines changed: 24 additions & 18 deletions
@@ -1,12 +1,10 @@
-use crate::api_bail;
-use crate::llm::{
-    LlmGenerateRequest, LlmGenerateResponse, LlmGenerationClient, LlmSpec, OutputFormat,
-    ToJsonSchemaOptions,
-};
-use anyhow::{bail, Context, Result};
 use async_trait::async_trait;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
 use serde_json::Value;
+use crate::api_bail;
 use urlencoding::encode;
+use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
 
 pub struct Client {
     model: String,
@@ -60,11 +58,14 @@ impl LlmGenerationClient for Client {
 
         // Prepare payload
         let mut payload = serde_json::json!({ "contents": contents });
-        if let Some(system) = request.system_prompt {
-            payload["systemInstruction"] = serde_json::json!({
-                "parts": [ { "text": system } ]
-            });
-        }
+        if let Some(mut system) = request.system_prompt {
+            if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
+                system = format!("{STRICT_JSON_PROMPT}\n\n{system}").into();
+            }
+            payload["systemInstruction"] = serde_json::json!({
+                "parts": [ { "text": system } ]
+            });
+        }
 
         // If structured output is requested, add schema and responseMimeType
         if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
@@ -79,13 +80,10 @@ impl LlmGenerationClient for Client {
         let api_key = &self.api_key;
         let url = format!(
             "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            encode(&self.model),
-            encode(api_key)
+            encode(&self.model), encode(api_key)
         );
 
-        let resp = self
-            .client
-            .post(&url)
+        let resp = self.client.post(&url)
             .json(&payload)
             .send()
             .await
@@ -102,7 +100,15 @@ impl LlmGenerationClient for Client {
             _ => bail!("No text in response"),
         };
 
-        Ok(LlmGenerateResponse { text })
+        // If output_format is JsonSchema, try to parse as JSON
+        if let Some(OutputFormat::JsonSchema { .. }) = request.output_format {
+            match serde_json::from_str::<serde_json::Value>(&text) {
+                Ok(val) => Ok(LlmGenerateResponse::Json(val)),
+                Err(_) => Ok(LlmGenerateResponse::Text(text)),
+            }
+        } else {
+            Ok(LlmGenerateResponse::Text(text))
+        }
     }
 
     fn json_schema_options(&self) -> ToJsonSchemaOptions {
@@ -113,4 +119,4 @@ impl LlmGenerationClient for Client {
             top_level_must_be_object: true,
         }
     }
-}
+}

src/llm/mod.rs

Lines changed: 5 additions & 3 deletions
@@ -22,7 +22,7 @@ pub struct LlmSpec {
     model: String,
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum OutputFormat<'a> {
     JsonSchema {
         name: Cow<'a, str>,
@@ -38,8 +38,9 @@ pub struct LlmGenerateRequest<'a> {
 }
 
 #[derive(Debug)]
-pub struct LlmGenerateResponse {
-    pub text: String,
+pub enum LlmGenerateResponse {
+    Json(serde_json::Value),
+    Text(String),
 }
 
 #[async_trait]
@@ -56,6 +57,7 @@ mod anthropic;
 mod gemini;
 mod ollama;
 mod openai;
+mod prompt_utils;
 
 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
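
With LlmGenerateResponse now an enum, callers match on the variant instead of re-parsing a text field. A minimal sketch of the new calling convention (hypothetical client and req values; it mirrors the extract_by_llm.rs change below):

// Sketch: consuming the new enum. `client` is assumed to be a
// Box<dyn LlmGenerationClient> from new_llm_generation_client, and `req`
// an LlmGenerateRequest that asked for a JSON schema.
let json_value: serde_json::Value = match client.generate(req).await? {
    // The client already parsed strict JSON.
    LlmGenerateResponse::Json(val) => val,
    // Fall back to parsing the raw text; propagate the error if malformed.
    LlmGenerateResponse::Text(text) => serde_json::from_str(&text)?,
};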

src/llm/ollama.rs

Lines changed: 15 additions & 4 deletions
@@ -1,6 +1,7 @@
 use super::LlmGenerationClient;
 use anyhow::Result;
 use async_trait::async_trait;
+use crate::llm::prompt_utils::STRICT_JSON_PROMPT;
 use schemars::schema::SchemaObject;
 use serde::{Deserialize, Serialize};
 
@@ -52,6 +53,10 @@ impl LlmGenerationClient for Client {
         &self,
         request: super::LlmGenerateRequest<'req>,
     ) -> Result<super::LlmGenerateResponse> {
+        let mut system_prompt = request.system_prompt.unwrap_or_default();
+        if matches!(request.output_format, Some(super::OutputFormat::JsonSchema { .. })) {
+            system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
+        }
         let req = OllamaRequest {
             model: &self.model,
             prompt: request.user_prompt.as_ref(),
@@ -60,7 +65,7 @@ impl LlmGenerationClient for Client {
                     OllamaFormat::JsonSchema(schema.as_ref())
                 },
             ),
-            system: request.system_prompt.as_ref().map(|s| s.as_ref()),
+            system: Some(&system_prompt),
             stream: Some(false),
         };
         let res = self
@@ -71,9 +76,15 @@ impl LlmGenerationClient for Client {
             .await?;
         let body = res.text().await?;
         let json: OllamaResponse = serde_json::from_str(&body)?;
-        Ok(super::LlmGenerateResponse {
-            text: json.response,
-        })
+        // Check if output_format is JsonSchema, try to parse as JSON
+        if let Some(super::OutputFormat::JsonSchema { .. }) = request.output_format {
+            match serde_json::from_str::<serde_json::Value>(&json.response) {
+                Ok(val) => Ok(super::LlmGenerateResponse::Json(val)),
+                Err(_) => Ok(super::LlmGenerateResponse::Text(json.response)),
+            }
+        } else {
+            Ok(super::LlmGenerateResponse::Text(json.response))
+        }
     }
 
     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {

src/llm/openai.rs

Lines changed: 13 additions & 3 deletions
@@ -64,8 +64,10 @@ impl LlmGenerationClient for Client {
             },
         ));
 
+        // Save output_format before it is moved.
+        let output_format = request.output_format.clone();
         // Create the chat completion request
-        let request = CreateChatCompletionRequest {
+        let openai_request = CreateChatCompletionRequest {
             model: self.model.clone(),
             messages,
             response_format: match request.output_format {
@@ -85,7 +87,7 @@ impl LlmGenerationClient for Client {
         };
 
         // Send request and get response
-        let response = self.client.chat().create(request).await?;
+        let response = self.client.chat().create(openai_request).await?;
 
         // Extract the response text from the first choice
         let text = response
@@ -95,7 +97,15 @@ impl LlmGenerationClient for Client {
             .and_then(|choice| choice.message.content)
             .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
 
-        Ok(super::LlmGenerateResponse { text })
+        // If output_format is JsonSchema, try to parse as JSON
+        if let Some(super::OutputFormat::JsonSchema { .. }) = output_format {
+            match serde_json::from_str::<serde_json::Value>(&text) {
+                Ok(val) => Ok(super::LlmGenerateResponse::Json(val)),
+                Err(_) => Ok(super::LlmGenerateResponse::Text(text)),
+            }
+        } else {
+            Ok(super::LlmGenerateResponse::Text(text))
+        }
     }
 
     fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
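
All four clients now inline the same Json-vs-Text fallback (Anthropic unconditionally, the others only when a schema was requested). For illustration, the shared logic in one place (a hypothetical helper, not part of this commit):

// Hypothetical helper (not in this commit): parse the model's text into the
// Json variant when JSON output was requested, else pass it through as Text.
fn parse_generate_response(text: String, json_requested: bool) -> LlmGenerateResponse {
    if json_requested {
        if let Ok(val) = serde_json::from_str::<serde_json::Value>(&text) {
            return LlmGenerateResponse::Json(val);
        }
    }
    LlmGenerateResponse::Text(text)
}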

src/llm/prompt_utils.rs

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+// Shared prompt utilities for LLM clients
+// Only import this in clients that require strict JSON output instructions (e.g., Anthropic, Gemini, Ollama)
+
+pub const STRICT_JSON_PROMPT: &str = "IMPORTANT: Output ONLY valid JSON that matches the schema. Do NOT say anything else. Do NOT explain. Do NOT preface. Do NOT add comments. If you cannot answer, output an empty JSON object: {}.";
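
The constant is consumed the same way in the Anthropic, Gemini, and Ollama clients above: when a JSON schema is requested, it is prepended to whatever system prompt the request carries. The recurring pattern, isolated from the diffs:

// The prepend pattern shared by the three clients (copied from the diffs
// above); system_prompt is a Cow<str>, so .into() converts the String back.
let mut system_prompt = request.system_prompt.unwrap_or_default();
if matches!(request.output_format, Some(OutputFormat::JsonSchema { .. })) {
    system_prompt = format!("{STRICT_JSON_PROMPT}\n\n{system_prompt}").into();
}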

src/ops/functions/extract_by_llm.rs

Lines changed: 6 additions & 3 deletions
@@ -1,7 +1,7 @@
 use crate::prelude::*;
 
 use crate::llm::{
-    new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmSpec, OutputFormat,
+    new_llm_generation_client, LlmGenerateRequest, LlmGenerationClient, LlmGenerateResponse, LlmSpec, OutputFormat,
 };
 use crate::ops::sdk::*;
 use base::json_schema::build_json_schema;
@@ -83,7 +83,10 @@ impl SimpleFunctionExecutor for Executor {
             }),
         };
         let res = self.client.generate(req).await?;
-        let json_value: serde_json::Value = serde_json::from_str(res.text.as_str())?;
+        let json_value = match res {
+            LlmGenerateResponse::Json(val) => val,
+            LlmGenerateResponse::Text(text) => serde_json::from_str(&text)?,
+        };
         let value = self.value_extractor.extract_value(json_value)?;
         Ok(value)
     }
@@ -124,4 +127,4 @@ impl SimpleFunctionFactoryBase for Factory {
     ) -> Result<Box<dyn SimpleFunctionExecutor>> {
         Ok(Box::new(Executor::new(spec, resolved_input_schema).await?))
     }
-}
+}
