Skip to content

Commit ffeae6b

Browse files
feat: added support for Azure OpenAI API
1 parent 2024c9d commit ffeae6b

File tree

3 files changed

+39
-8
lines changed

3 files changed

+39
-8
lines changed

.env.example

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ TRAINEE_MODEL=gpt-4o-mini
1414
TRAINEE_BASE_URL=
1515
TRAINEE_API_KEY=
1616

17+
# azure_openai_api
18+
# SYNTHESIZER_BACKEND=azure_openai_api
19+
# The following is the same as your "Deployment name" in Azure
20+
# SYNTHESIZER_MODEL=<your-deployment-name>
21+
# SYNTHESIZER_BASE_URL=https://<your-resource-name>.openai.azure.com
22+
# SYNTHESIZER_API_KEY=
23+
# SYNTHESIZER_API_VERSION=<api-version>
24+
1725
# # ollama_api
1826
# SYNTHESIZER_BACKEND=ollama_api
1927
# SYNTHESIZER_MODEL=gemma3

graphgen/models/llm/api/openai_client.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from typing import Any, Dict, List, Optional
33

44
import openai
5-
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, RateLimitError
5+
from openai import APIConnectionError, APITimeoutError, AsyncOpenAI, AsyncAzureOpenAI, RateLimitError
6+
from pyparsing import Literal
67
from tenacity import (
78
retry,
89
retry_if_exception_type,
@@ -35,17 +36,20 @@ def __init__(
3536
model: str = "gpt-4o-mini",
3637
api_key: Optional[str] = None,
3738
base_url: Optional[str] = None,
39+
api_version: Optional[str] = None,
3840
json_mode: bool = False,
3941
seed: Optional[int] = None,
4042
topk_per_token: int = 5, # number of topk tokens to generate for each token
4143
request_limit: bool = False,
4244
rpm: Optional[RPM] = None,
4345
tpm: Optional[TPM] = None,
46+
backend: str = "openai_api",
4447
**kwargs: Any,
4548
):
4649
super().__init__(**kwargs)
4750
self.model = model
4851
self.api_key = api_key
52+
self.api_version = api_version # required for Azure OpenAI
4953
self.base_url = base_url
5054
self.json_mode = json_mode
5155
self.seed = seed
@@ -56,13 +60,31 @@ def __init__(
5660
self.rpm = rpm or RPM()
5761
self.tpm = tpm or TPM()
5862

63+
assert backend in ["openai_api", "azure_openai_api"], f"Unsupported backend {backend}. Use 'openai_api' or 'azure_openai_api'."
64+
self.backend = backend
65+
5966
self.__post_init__()
6067

6168
def __post_init__(self):
62-
assert self.api_key is not None, "Please provide api key to access openai api."
63-
self.client = AsyncOpenAI(
64-
api_key=self.api_key or "dummy", base_url=self.base_url
65-
)
69+
70+
api_name = self.backend.replace("_", " ")
71+
assert self.api_key is not None, f"Please provide api key to access {api_name}."
72+
73+
if self.backend == "openai_api":
74+
self.client = AsyncOpenAI(
75+
api_key=self.api_key or "dummy", base_url=self.base_url
76+
)
77+
elif self.backend == "azure_openai_api":
78+
assert self.api_version is not None, f"Please provide api_version for {api_name}."
79+
assert self.base_url is not None, f"Please provide base_url for {api_name}."
80+
self.client = AsyncAzureOpenAI(
81+
api_key=self.api_key,
82+
azure_endpoint=self.base_url,
83+
api_version=self.api_version,
84+
azure_deployment=self.model,
85+
)
86+
else:
87+
raise ValueError(f"Unsupported backend {self.backend}. Use 'openai_api' or 'azure_openai_api'.")
6688

6789
def _pre_generate(self, text: str, history: List[str]) -> Dict:
6890
kwargs = {

graphgen/operators/init/init_llm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ def create_llm_wrapper(backend: str, config: Dict[str, Any]) -> BaseLLMWrapper:
2727
from graphgen.models.llm.api.http_client import HTTPClient
2828

2929
return HTTPClient(**config)
30-
if backend == "openai_api":
30+
if backend == "openai_api" or backend == "azure_openai_api":
3131
from graphgen.models.llm.api.openai_client import OpenAIClient
32-
33-
return OpenAIClient(**config)
32+
33+
# pass in concrete backend to the OpenAIClient so that internally we can distinguish between OpenAI and Azure OpenAI
34+
return OpenAIClient(**config, backend=backend)
3435
if backend == "ollama_api":
3536
from graphgen.models.llm.api.ollama_client import OllamaClient
3637

0 commit comments

Comments (0)