Skip to content

Commit af5e7a2

Browse files
committed
fix: OpenAIClient parameter from model_name to model to resolve key mismatch
1 parent 866dc7c commit af5e7a2

File tree

6 files changed

+10
-10
lines changed

6 files changed

+10
-10
lines changed

baselines/Genie/genie.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ async def process_chunk(content: str):
122122
load_dotenv()
123123

124124
llm_client = OpenAIClient(
125-
model_name=os.getenv("SYNTHESIZER_MODEL"),
125+
model=os.getenv("SYNTHESIZER_MODEL"),
126126
api_key=os.getenv("SYNTHESIZER_API_KEY"),
127127
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
128128
)

baselines/LongForm/longform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ async def process_chunk(content: str):
8989
load_dotenv()
9090

9191
llm_client = OpenAIClient(
92-
model_name=os.getenv("SYNTHESIZER_MODEL"),
92+
model=os.getenv("SYNTHESIZER_MODEL"),
9393
api_key=os.getenv("SYNTHESIZER_API_KEY"),
9494
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
9595
)

baselines/SELF-QA/self-qa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ async def process_chunk(content: str):
156156
load_dotenv()
157157

158158
llm_client = OpenAIClient(
159-
model_name=os.getenv("SYNTHESIZER_MODEL"),
159+
model=os.getenv("SYNTHESIZER_MODEL"),
160160
api_key=os.getenv("SYNTHESIZER_API_KEY"),
161161
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
162162
)

baselines/Wrap/wrap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ async def process_chunk(content: str):
109109
load_dotenv()
110110

111111
llm_client = OpenAIClient(
112-
model_name=os.getenv("SYNTHESIZER_MODEL"),
112+
model=os.getenv("SYNTHESIZER_MODEL"),
113113
api_key=os.getenv("SYNTHESIZER_API_KEY"),
114114
base_url=os.getenv("SYNTHESIZER_BASE_URL"),
115115
)

graphgen/models/llm/api/openai_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class OpenAIClient(BaseLLMWrapper):
3232
def __init__(
3333
self,
3434
*,
35-
model_name: str = "gpt-4o-mini",
35+
model: str = "gpt-4o-mini",
3636
api_key: Optional[str] = None,
3737
base_url: Optional[str] = None,
3838
json_mode: bool = False,
@@ -44,7 +44,7 @@ def __init__(
4444
**kwargs: Any,
4545
):
4646
super().__init__(**kwargs)
47-
self.model_name = model_name
47+
self.model = model
4848
self.api_key = api_key
4949
self.base_url = base_url
5050
self.json_mode = json_mode
@@ -109,7 +109,7 @@ async def generate_topk_per_token(
109109
kwargs["max_tokens"] = 1
110110

111111
completion = await self.client.chat.completions.create( # pylint: disable=E1125
112-
model=self.model_name, **kwargs
112+
model=self.model, **kwargs
113113
)
114114

115115
tokens = get_top_response_tokens(completion)
@@ -141,7 +141,7 @@ async def generate_answer(
141141
await self.tpm.wait(estimated_tokens, silent=True)
142142

143143
completion = await self.client.chat.completions.create( # pylint: disable=E1125
144-
model=self.model_name, **kwargs
144+
model=self.model, **kwargs
145145
)
146146
if hasattr(completion, "usage"):
147147
self.token_usage.append(

webui/app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
4242

4343
tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
4444
synthesizer_llm_client = OpenAIClient(
45-
model_name=env.get("SYNTHESIZER_MODEL", ""),
45+
model=env.get("SYNTHESIZER_MODEL", ""),
4646
base_url=env.get("SYNTHESIZER_BASE_URL", ""),
4747
api_key=env.get("SYNTHESIZER_API_KEY", ""),
4848
request_limit=True,
@@ -51,7 +51,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
5151
tokenizer=tokenizer_instance,
5252
)
5353
trainee_llm_client = OpenAIClient(
54-
model_name=env.get("TRAINEE_MODEL", ""),
54+
model=env.get("TRAINEE_MODEL", ""),
5555
base_url=env.get("TRAINEE_BASE_URL", ""),
5656
api_key=env.get("TRAINEE_API_KEY", ""),
5757
request_limit=True,

0 commit comments

Comments
 (0)