Skip to content

Commit dfa38b9

Browse files
authored
fix the generator modules (#1153)
1 parent 63a4503 commit dfa38b9

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

autorag/autorag/nodes/generator/llama_index_llm.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import List, Tuple, Union
23

34
import pandas as pd
@@ -13,7 +14,10 @@
1314
pop_params,
1415
is_chat_prompt,
1516
)
16-
from llama_index.core.llms import ChatMessage
17+
from llama_index.core.llms import ChatMessage, ChatResponse
18+
19+
20+
logger = logging.getLogger("AutoRAG")
1721

1822

1923
class LlamaIndexLLM(BaseGenerator):
@@ -115,17 +119,20 @@ def __pure_chat(
115119
]
116120
tasks = [self.llm_instance.achat(msg) for msg in llama_index_messages]
117121
loop = get_event_loop()
118-
results: List[ChatMessage] = loop.run_until_complete(
122+
results: List[ChatResponse] = loop.run_until_complete(
119123
process_batch(tasks, batch_size=self.batch)
120124
)
121125

122126
generated_texts = [res.message.content for res in results]
123127
# Check is there a logprob available
124-
if results[0].logprobs is not None:
128+
if all(res.logprobs is not None for res in results):
125129
retrieved_logprobs = [res.logprobs for res in results]
126130
tokenized_ids = [logprob.token for logprob in retrieved_logprobs]
127131
logprobs = [logprob.logprob for logprob in retrieved_logprobs]
128132
else:
133+
logger.warning(
134+
"Logprobs are not available from the LLM. So, returning pseudo logprobs."
135+
)
129136
tokenized_ids = self.get_default_tokenized_ids(generated_texts)
130137
logprobs = self.get_default_log_probs(tokenized_ids)
131138

autorag/autorag/nodes/generator/openai_llm.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"o1-mini-2024-09-12": 128_000,
3737
"o1-pro": 200_000,
3838
"o1-pro-2025-03-19": 200_000,
39-
"o3": 128_000,
39+
"o3": 200_000,
4040
"o3-mini": 200_000,
4141
"o3-mini-2025-01-31": 200_000,
4242
"o4-mini": 200_000,
@@ -76,7 +76,10 @@ def __init__(self, project_dir, llm: str, batch: int = 16, *args, **kwargs):
7676

7777
client_init_params = pop_params(AsyncOpenAI.__init__, kwargs)
7878
self.client = AsyncOpenAI(**client_init_params)
79-
self.tokenizer = tiktoken.encoding_for_model(self.llm)
79+
try:
80+
self.tokenizer = tiktoken.encoding_for_model(self.llm)
81+
except KeyError:
82+
self.tokenizer = tiktoken.get_encoding("o200k_base")
8083

8184
self.max_token_size = (
8285
MAX_TOKEN_DICT.get(self.llm) - 7
@@ -144,6 +147,7 @@ def _pure(
144147
self.llm.startswith("o1")
145148
or self.llm.startswith("o3")
146149
or self.llm.startswith("o4")
150+
or self.llm.startswith("gpt-5")
147151
):
148152
tasks = [
149153
self.get_result_reasoning(prompt, **openai_chat_params)
@@ -178,7 +182,7 @@ def structured_output(self, prompts: List[str], output_cls, **kwargs):
178182
)
179183
)
180184

181-
openai_chat_params = pop_params(self.client.beta.chat.completions.parse, kwargs)
185+
openai_chat_params = pop_params(self.client.responses.parse, kwargs)
182186
loop = get_event_loop()
183187
tasks = [
184188
self.get_structured_result(prompt, output_cls, **openai_chat_params)
@@ -230,16 +234,13 @@ async def get_structured_result(
230234
]:
231235
raise ValueError("structured output is supported after the gpt-4o model.")
232236

233-
logprobs = True
234-
response = await self.client.beta.chat.completions.parse(
237+
response = await self.client.responses.parse(
235238
model=self.llm,
236-
messages=parse_prompt(prompt),
237-
response_format=output_cls,
238-
logprobs=logprobs,
239-
n=1,
239+
input=parse_prompt(prompt),
240+
text_format=output_cls,
240241
**kwargs,
241242
)
242-
return response.choices[0].message.parsed
243+
return response.output_parsed
243244

244245
async def get_result(self, prompt: Union[str, List[dict]], **kwargs):
245246
logprobs = True
@@ -254,13 +255,11 @@ async def get_result(self, prompt: Union[str, List[dict]], **kwargs):
254255
)
255256
choice = response.choices[0]
256257
answer = choice.message.content
257-
logprobs = list(map(lambda x: x.logprob, choice.logprobs.content))
258-
tokens = list(
259-
map(
260-
lambda x: self.tokenizer.encode(x.token, allowed_special="all")[0],
261-
choice.logprobs.content,
262-
)
263-
)
258+
logprobs = [x.logprob for x in choice.logprobs.content]
259+
tokens = [
260+
self.tokenizer.encode(x.token, allowed_special="all")[0]
261+
for x in choice.logprobs.content
262+
]
264263
if len(tokens) != len(logprobs):
265264
raise ValueError("tokens and logprobs size is different.")
266265
return answer, tokens, logprobs
@@ -270,8 +269,9 @@ async def get_result_reasoning(self, prompt: Union[str, List[dict]], **kwargs):
270269
self.llm.startswith("o1")
271270
or self.llm.startswith("o3")
272271
or self.llm.startswith("o4")
272+
or self.llm.startswith("gpt-5")
273273
):
274-
raise ValueError("get_result_reasoning is only for o1,o3, o4 models.")
274+
raise ValueError("get_result_reasoning is only for o1, o3, o4, gpt-5 models.")
275275
# The default temperature for the o1 model is 1. 1 is only supported.
276276
# See https://platform.openai.com/docs/guides/reasoning about beta limitation of o1 models.
277277
unsupported_params = [

0 commit comments

Comments (0)