
Commit 8ea93fd

Support deepseek

1 parent 608e446 · commit 8ea93fd

File tree

aiscript-vm/src/ai/agent.rs
aiscript-vm/src/ai/mod.rs
aiscript-vm/src/ai/prompt.rs
aiscript-vm/src/vm/state.rs

4 files changed: +91 -31 lines changed

aiscript-vm/src/ai/agent.rs

Lines changed: 4 additions & 3 deletions
@@ -8,7 +8,6 @@ use openai_api_rs::v1::{
         ChatCompletionMessage, ChatCompletionMessageForResponse, ChatCompletionRequest, Content,
         MessageRole, Tool, ToolCall, ToolChoiceType, ToolType,
     },
-    common::GPT3_5_TURBO,
     types::{self, FunctionParameters, JSONSchemaDefine},
 };
 use tokio::runtime::Handle;
@@ -278,6 +277,8 @@ pub async fn _run_agent<'gc>(
     mut agent: Gc<'gc, Agent<'gc>>,
     args: Vec<Value<'gc>>,
 ) -> Value<'gc> {
+    use super::default_model;
+
     let message = args[0];
     let debug = args[1].as_boolean();
     let mut history = Vec::new();
@@ -288,11 +289,11 @@ pub async fn _run_agent<'gc>(
         tool_calls: None,
         tool_call_id: None,
     });
-    let mut client = super::openai_client();
+    let mut client = super::openai_client(state.ai_config.as_ref());
     loop {
         let mut messages = vec![agent.get_instruction_message()];
         messages.extend(history.clone());
-        let mut req = ChatCompletionRequest::new(GPT3_5_TURBO.to_string(), messages);
+        let mut req = ChatCompletionRequest::new(default_model(state.ai_config.as_ref()), messages);
         let tools = agent.get_tools();
         if !tools.is_empty() {
             req = req
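The hardcoded GPT3_5_TURBO is gone from this file; the request now asks default_model (added in aiscript-vm/src/ai/mod.rs below) for a per-provider default. A minimal sketch of the resolution behavior, assuming those types are in scope; the ModelConfig literal and key are hypothetical:

// Illustrative only: no config falls back to GPT3_5_TURBO; a DeepSeek
// config without an explicit model falls back to DEEPSEEK_V3.
assert_eq!(default_model(None), GPT3_5_TURBO);

let cfg = AiConfig::DeepSeek(ModelConfig {
    api_key: "sk-...".to_string(), // placeholder key
    model: None,                   // no explicit model configured
});
assert_eq!(default_model(Some(&cfg)), "deepseek-chat");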

aiscript-vm/src/ai/mod.rs

Lines changed: 59 additions & 11 deletions
@@ -4,16 +4,22 @@ mod prompt;
 use std::env;
 
 pub use agent::{Agent, run_agent};
-use openai_api_rs::v1::api::OpenAIClient;
+use openai_api_rs::v1::{api::OpenAIClient, common::GPT3_5_TURBO};
 pub use prompt::{PromptConfig, prompt_with_config};
 
 use serde::Deserialize;
 
-#[derive(Debug, Clone, Deserialize, Default)]
-pub struct AiConfig {
-    pub openai: Option<ModelConfig>,
-    pub anthropic: Option<ModelConfig>,
-    pub deepseek: Option<ModelConfig>,
+const DEEPSEEK_API_ENDPOINT: &str = "https://api.deepseek.com/v1";
+const DEEPSEEK_V3: &str = "deepseek-chat";
+
+#[derive(Debug, Clone, Deserialize)]
+pub enum AiConfig {
+    #[serde(rename = "openai")]
+    OpenAI(ModelConfig),
+    #[serde(rename = "anthropic")]
+    Anthropic(ModelConfig),
+    #[serde(rename = "deepseek")]
+    DeepSeek(ModelConfig),
 }
 
 #[derive(Debug, Clone, Deserialize)]
@@ -22,10 +28,52 @@ pub struct ModelConfig {
     pub model: Option<String>,
 }
 
+impl AiConfig {
+    pub(crate) fn take_model(&mut self) -> Option<String> {
+        match self {
+            Self::OpenAI(ModelConfig { model, .. }) => model.take(),
+            Self::Anthropic(ModelConfig { model, .. }) => model.take(),
+            Self::DeepSeek(ModelConfig { model, .. }) => model.take(),
+        }
+    }
+}
+
 #[allow(unused)]
-pub(crate) fn openai_client() -> OpenAIClient {
-    OpenAIClient::builder()
-        .with_api_key(env::var("OPENAI_API_KEY").unwrap().to_string())
-        .build()
-        .unwrap()
+pub(crate) fn openai_client(config: Option<&AiConfig>) -> OpenAIClient {
+    match config {
+        None => OpenAIClient::builder()
+            .with_api_key(env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"))
+            .build()
+            .unwrap(),
+        Some(AiConfig::OpenAI(model_config)) => {
+            let api_key = if model_config.api_key.is_empty() {
+                env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set")
+            } else {
+                model_config.api_key.clone()
+            };
+            OpenAIClient::builder()
+                .with_api_key(api_key)
+                .build()
+                .unwrap()
+        }
+        Some(AiConfig::DeepSeek(ModelConfig { api_key, .. })) => OpenAIClient::builder()
+            .with_endpoint(DEEPSEEK_API_ENDPOINT)
+            .with_api_key(api_key)
+            .build()
+            .unwrap(),
+        Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
+    }
+}
+
+pub(crate) fn default_model(config: Option<&AiConfig>) -> String {
+    match config {
+        None => GPT3_5_TURBO.to_string(),
+        Some(AiConfig::OpenAI(ModelConfig { model, .. })) => {
+            model.clone().unwrap_or(GPT3_5_TURBO.to_string())
+        }
+        Some(AiConfig::DeepSeek(ModelConfig { model, .. })) => {
+            model.clone().unwrap_or(DEEPSEEK_V3.to_string())
+        }
+        Some(AiConfig::Anthropic(_)) => unimplemented!("Anthropic API not yet supported"),
+    }
 }
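A note on the shape this deserializes from: with the #[serde(rename = ...)] attributes, AiConfig is an externally tagged enum, so a single provider key selects the variant. A sketch of a matching TOML source; the toml crate and this loading code are assumptions for illustration, not part of the commit:

// Hypothetical: the table name picks the AiConfig variant.
let src = r#"
[deepseek]
api_key = "sk-..."
model = "deepseek-chat"
"#;
let config: AiConfig = toml::from_str(src).expect("valid provider config");
assert!(matches!(config, AiConfig::DeepSeek(_)));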

aiscript-vm/src/ai/prompt.rs

Lines changed: 22 additions & 15 deletions
@@ -1,9 +1,11 @@
 use openai_api_rs::v1::common::GPT3_5_TURBO;
 use tokio::runtime::Handle;
 
+use super::{AiConfig, ModelConfig, default_model};
+
 pub struct PromptConfig {
     pub input: String,
-    pub model: Option<String>,
+    pub ai_config: Option<AiConfig>,
     pub max_tokens: Option<i64>,
     pub temperature: Option<f64>,
     pub system_prompt: Option<String>,
@@ -13,27 +15,38 @@ impl Default for PromptConfig {
     fn default() -> Self {
         Self {
             input: String::new(),
-            model: Some(GPT3_5_TURBO.to_string()),
+            ai_config: Some(AiConfig::OpenAI(ModelConfig {
+                api_key: Default::default(),
+                model: Some(GPT3_5_TURBO.to_string()),
+            })),
             max_tokens: Default::default(),
             temperature: Default::default(),
             system_prompt: Default::default(),
         }
     }
 }
 
+impl PromptConfig {
+    fn take_model(&mut self) -> String {
+        self.ai_config
+            .as_mut()
+            .and_then(|config| config.take_model())
+            .unwrap_or_else(|| default_model(self.ai_config.as_ref()))
+    }
+
+    pub(crate) fn set_model(&mut self, model: String) {}
+}
+
 #[cfg(feature = "ai_test")]
 async fn _prompt_with_config(config: PromptConfig) -> String {
     return format!("AI: {}", config.input);
 }
 
 #[cfg(not(feature = "ai_test"))]
 async fn _prompt_with_config(mut config: PromptConfig) -> String {
-    use openai_api_rs::v1::{
-        chat_completion::{self, ChatCompletionRequest},
-        common::GPT3_5_TURBO,
-    };
-
-    let mut client = super::openai_client();
+    use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
+    let mut client = super::openai_client(config.ai_config.as_ref());
+    let model = config.take_model();
 
     // Create system message if provided
     let mut messages = Vec::new();
@@ -57,13 +70,7 @@ async fn _prompt_with_config(mut config: PromptConfig) -> String {
     });
 
     // Build the request
-    let mut req = ChatCompletionRequest::new(
-        config
-            .model
-            .take()
-            .unwrap_or_else(|| GPT3_5_TURBO.to_string()),
-        messages,
-    );
+    let mut req = ChatCompletionRequest::new(model, messages);
 
     if let Some(max_tokens) = config.max_tokens {
         req.max_tokens = Some(max_tokens);
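One caveat worth flagging: as committed, set_model has an empty body, so the per-call model override that state.rs routes through it (next diff) is silently dropped. A sketch of the presumably intended behavior, assuming the override should land on the active provider's ModelConfig:

pub(crate) fn set_model(&mut self, model: String) {
    // Presumed intent: store the override on whichever provider is active.
    if let Some(
        AiConfig::OpenAI(config) | AiConfig::Anthropic(config) | AiConfig::DeepSeek(config),
    ) = self.ai_config.as_mut()
    {
        config.model = Some(model);
    }
}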

aiscript-vm/src/vm/state.rs

Lines changed: 6 additions & 2 deletions
@@ -1016,12 +1016,16 @@ impl<'gc> State<'gc> {
                 let input = s.to_str().unwrap().to_string();
                 ai::prompt_with_config(PromptConfig {
                     input,
+                    ai_config: self.ai_config.clone(),
                     ..Default::default()
                 })
             }
             // Object config case
             Value::Object(obj) => {
-                let mut config = PromptConfig::default();
+                let mut config = PromptConfig {
+                    ai_config: self.ai_config.clone(),
+                    ..Default::default()
+                };
                 let obj_ref = obj.borrow();
 
                 // Extract input (required)
@@ -1039,7 +1043,7 @@ impl<'gc> State<'gc> {
                 if let Some(Value::String(model)) =
                     obj_ref.fields.get(&self.intern(b"model"))
                 {
-                    config.set_model(model.to_str().unwrap().to_string());
+                    config.set_model(model.to_str().unwrap().to_string());
                 }
 
                 // Extract max_tokens (optional)
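Taken together, both branches now seed PromptConfig with the VM's configured provider before applying per-call fields. A condensed sketch of the resulting flow with the VM plumbing elided; the values are hypothetical, and a blocking prompt_with_config(PromptConfig) -> String wrapper is assumed from its pub use in mod.rs:

// Illustrative stand-in for what the Object branch assembles.
let config = PromptConfig {
    input: "Explain borrowing in one sentence.".to_string(),
    ai_config: Some(AiConfig::DeepSeek(ModelConfig {
        api_key: "sk-...".to_string(), // placeholder
        model: None, // take_model() falls back to DEEPSEEK_V3
    })),
    ..Default::default()
};
let answer = prompt_with_config(config);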
