Skip to content

Commit 97b49e9

Browse files
authored
feat(openai-config): allow custom address, org_id and project_id (#852)
1 parent 9f69850 commit 97b49e9

File tree

4 files changed

+43
-8
lines changed

4 files changed

+43
-8
lines changed

docs/docs/ai/llm.mdx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ CocoIndex integrates with various LLM APIs for these functions.
7575
To use the OpenAI LLM API, you need to set the environment variable `OPENAI_API_KEY`.
7676
You can generate the API key from [OpenAI Dashboard](https://platform.openai.com/api-keys).
7777

78-
Currently we don't support custom address for OpenAI API.
78+
If you want to use a custom address, you can either provide the `address` parameter in `LlmSpec` / `EmbedText`, or set the environment variable `OPENAI_API_BASE`. The `address` parameter takes precedence over the environment variable.
79+
80+
Spec for OpenAI takes additional `api_config` field, in type `cocoindex.llm.OpenAiConfig` with the following fields:
81+
- `org_id` (type: `str`, optional): The organization ID of the OpenAI account.
82+
- `project_id` (type: `str`, optional): The project ID of the OpenAI account.
7983

8084
You can find the full list of models supported by OpenAI [here](https://platform.openai.com/docs/models).
8185

python/cocoindex/llm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,21 @@ class VertexAiConfig:
2626
region: str | None = None
2727

2828

29+
@dataclass
30+
class OpenAiConfig:
31+
"""A specification for a OpenAI LLM."""
32+
33+
kind = "OpenAi"
34+
35+
org_id: str | None = None
36+
project_id: str | None = None
37+
38+
2939
@dataclass
3040
class LlmSpec:
3141
"""A specification for a LLM."""
3242

3343
api_type: LlmApiType
3444
model: str
3545
address: str | None = None
36-
api_config: VertexAiConfig | None = None
46+
api_config: VertexAiConfig | OpenAiConfig | None = None

src/llm/mod.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,17 @@ pub struct VertexAiConfig {
2626
pub region: Option<String>,
2727
}
2828

29+
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
30+
pub struct OpenAiConfig {
31+
pub org_id: Option<String>,
32+
pub project_id: Option<String>,
33+
}
34+
2935
#[derive(Debug, Clone, Serialize, Deserialize)]
3036
#[serde(tag = "kind")]
3137
pub enum LlmApiConfig {
3238
VertexAi(VertexAiConfig),
39+
OpenAi(OpenAiConfig),
3340
}
3441

3542
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -113,7 +120,7 @@ pub async fn new_llm_generation_client(
113120
Box::new(ollama::Client::new(address).await?) as Box<dyn LlmGenerationClient>
114121
}
115122
LlmApiType::OpenAi => {
116-
Box::new(openai::Client::new(address)?) as Box<dyn LlmGenerationClient>
123+
Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmGenerationClient>
117124
}
118125
LlmApiType::Gemini => {
119126
Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
@@ -151,7 +158,7 @@ pub async fn new_llm_embedding_client(
151158
Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
152159
}
153160
LlmApiType::OpenAi => {
154-
Box::new(openai::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
161+
Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmEmbeddingClient>
155162
}
156163
LlmApiType::Voyage => {
157164
Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>

src/llm/openai.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,31 @@ impl Client {
3333
Self { client }
3434
}
3535

36-
pub fn new(address: Option<String>) -> Result<Self> {
36+
pub fn new(address: Option<String>, api_config: Option<super::LlmApiConfig>) -> Result<Self> {
37+
let config = match api_config {
38+
Some(super::LlmApiConfig::OpenAi(config)) => config,
39+
Some(_) => api_bail!("unexpected config type, expected OpenAiConfig"),
40+
None => super::OpenAiConfig::default(),
41+
};
42+
43+
let mut openai_config = OpenAIConfig::new();
3744
if let Some(address) = address {
38-
api_bail!("OpenAI doesn't support custom API address: {address}");
45+
openai_config = openai_config.with_api_base(address);
46+
}
47+
if let Some(org_id) = config.org_id {
48+
openai_config = openai_config.with_org_id(org_id);
3949
}
50+
if let Some(project_id) = config.project_id {
51+
openai_config = openai_config.with_project_id(project_id);
52+
}
53+
4054
// Verify API key is set
4155
if std::env::var("OPENAI_API_KEY").is_err() {
4256
api_bail!("OPENAI_API_KEY environment variable must be set");
4357
}
4458
Ok(Self {
45-
// OpenAI client will use OPENAI_API_KEY env variable by default
46-
client: OpenAIClient::new(),
59+
// OpenAI client will use OPENAI_API_KEY and OPENAI_API_BASE env variables by default
60+
client: OpenAIClient::with_config(openai_config),
4761
})
4862
}
4963
}

0 commit comments

Comments
 (0)