Skip to content

Commit c4ed989

Browse files
HerrIvan and Ivan Herreros
authored
fix: temperature parameter in generate_text not ignored. (#887)
Addresses #886. The temperature parameter is no longer overwritten by the call to `self.get_temperature(n=n)`. Now that method is only called if no temperature argument was given. --------- Co-authored-by: Ivan Herreros <[email protected]>
1 parent 1c3b6d3 commit c4ed989

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/ragas/llms/base.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ async def agenerate_text(
7070
self,
7171
prompt: PromptValue,
7272
n: int = 1,
73-
temperature: float = 1e-8,
73+
temperature: t.Optional[float] = None,
7474
stop: t.Optional[t.List[str]] = None,
7575
callbacks: Callbacks = None,
7676
) -> LLMResult:
@@ -80,11 +80,13 @@ async def generate(
8080
self,
8181
prompt: PromptValue,
8282
n: int = 1,
83-
temperature: float = 1e-8,
83+
temperature: t.Optional[float] = None,
8484
stop: t.Optional[t.List[str]] = None,
8585
callbacks: Callbacks = None,
8686
is_async: bool = True,
8787
) -> LLMResult:
88+
if temperature is None:
89+
temperature = 1e-8
8890
"""Generate text using the given event loop."""
8991
if is_async:
9092
agenerate_text_with_retry = add_async_retry(
@@ -131,11 +133,14 @@ def generate_text(
131133
self,
132134
prompt: PromptValue,
133135
n: int = 1,
134-
temperature: float = 1e-8,
136+
temperature: t.Optional[float] = None,
135137
stop: t.Optional[t.List[str]] = None,
136138
callbacks: Callbacks = None,
137139
) -> LLMResult:
138-
temperature = self.get_temperature(n=n)
140+
# figure out the temperature to set
141+
if temperature is None:
142+
temperature = self.get_temperature(n=n)
143+
139144
if is_multiple_completion_supported(self.langchain_llm):
140145
return self.langchain_llm.generate_prompt(
141146
prompts=[prompt],
@@ -161,11 +166,12 @@ async def agenerate_text(
161166
self,
162167
prompt: PromptValue,
163168
n: int = 1,
164-
temperature: float = 1e-8,
169+
temperature: t.Optional[float] = None,
165170
stop: t.Optional[t.List[str]] = None,
166171
callbacks: Callbacks = None,
167172
) -> LLMResult:
168-
temperature = self.get_temperature(n=n)
173+
if temperature is None:
174+
temperature = self.get_temperature(n=n)
169175
if is_multiple_completion_supported(self.langchain_llm):
170176
return await self.langchain_llm.agenerate_prompt(
171177
prompts=[prompt],

0 commit comments

Comments
 (0)