Skip to content

Commit 6136e12

Browse files
authored
fix: temperature change to 0.01 bugs (#2247)
1 parent 82344ab commit 6136e12

File tree

7 files changed

+30
-9
lines changed

7 files changed

+30
-9
lines changed

src/ragas/llms/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def generate_text(
8080
temperature: float = 0.01,
8181
stop: t.Optional[t.List[str]] = None,
8282
callbacks: Callbacks = None,
83-
) -> LLMResult:
84-
...
83+
) -> LLMResult: ...
8584

8685
@abstractmethod
8786
async def agenerate_text(
@@ -91,7 +90,11 @@ async def agenerate_text(
9190
temperature: t.Optional[float] = 0.01,
9291
stop: t.Optional[t.List[str]] = None,
9392
callbacks: Callbacks = None,
94-
) -> LLMResult:
93+
) -> LLMResult: ...
94+
95+
@abstractmethod
96+
def is_finished(self, response: LLMResult) -> bool:
97+
"""Check if the LLM response is finished/complete."""
9598
...
9699

97100
async def generate(
@@ -335,7 +338,7 @@ def check_args(
335338
) -> dict[str, t.Any]:
336339
if n != 1:
337340
logger.warning("n values greater than 1 not support for LlamaIndex LLMs")
338-
if temperature != 1e-8:
341+
if temperature != 0.01:
339342
logger.info("temperature kwarg passed to LlamaIndex LLM")
340343
if stop is not None:
341344
logger.info("stop kwarg passed to LlamaIndex LLM")
@@ -359,7 +362,7 @@ def generate_text(
359362
self,
360363
prompt: PromptValue,
361364
n: int = 1,
362-
temperature: float = 1e-8,
365+
temperature: float = 0.01,
363366
stop: t.Optional[t.List[str]] = None,
364367
callbacks: Callbacks = None,
365368
) -> LLMResult:

src/ragas/llms/haystack_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def generate_text(
8484
self,
8585
prompt: PromptValue,
8686
n: int = 1,
87-
temperature: float = 1e-8,
87+
temperature: float = 0.01,
8888
stop: t.Optional[t.List[str]] = None,
8989
callbacks: t.Optional[Callbacks] = None,
9090
) -> LLMResult:

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ async def agenerate_text( # type: ignore
4848
) -> LLMResult:
4949
return LLMResult(generations=[[Generation(text=prompt.to_string())]])
5050

51+
def is_finished(self, response: LLMResult) -> bool:
52+
return True
53+
5154

5255
class EchoEmbedding(BaseRagasEmbeddings):
5356
async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:

tests/unit/llms/test_llm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def generate_text(
1818
self,
1919
prompt: PromptValue,
2020
n=1,
21-
temperature: float = 1e-8,
21+
temperature: float = 0.01,
2222
stop=None,
2323
callbacks=[],
2424
):
@@ -29,9 +29,12 @@ async def agenerate_text(
2929
self,
3030
prompt: PromptValue,
3131
n=1,
32-
temperature: t.Optional[float] = 1e-8,
32+
temperature: t.Optional[float] = 0.01,
3333
stop=None,
3434
callbacks=[],
3535
):
36-
temp_val = temperature if temperature is not None else 1e-8
36+
temp_val = temperature if temperature is not None else 0.01
3737
return self.generate_text(prompt, n, temp_val, stop, callbacks)
38+
39+
def is_finished(self, response: LLMResult) -> bool:
40+
return True

tests/unit/test_analytics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ async def agenerate_text( # type: ignore
2626
) -> LLMResult:
2727
return LLMResult(generations=[[Generation(text=prompt.to_string())]])
2828

29+
def is_finished(self, response: LLMResult) -> bool:
30+
return True
31+
2932

3033
def test_debug_tracking_flag(monkeypatch):
3134
import os

tests/unit/test_import.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def mocked_import(name, *args, **kwargs):
3636
HaystackLLMWrapper(haystack_generator=None)
3737

3838

39+
@pytest.mark.filterwarnings(
40+
"ignore:LangchainEmbeddingsWrapper is deprecated:DeprecationWarning"
41+
)
42+
@pytest.mark.filterwarnings(
43+
"ignore:LlamaIndexEmbeddingsWrapper is deprecated:DeprecationWarning"
44+
)
3945
def test_wrappers_with_missing_haystack(monkeypatch):
4046
"""Simulate missing 'haystack' and verify that:
4147
- Non-Haystack wrappers import and instantiate without error.

tests/unit/test_prompt.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ async def agenerate_text( # type: ignore
2929
) -> LLMResult:
3030
return LLMResult(generations=[[Generation(text=prompt.to_string())]])
3131

32+
def is_finished(self, response: LLMResult) -> bool:
33+
return True
34+
3235

3336
@pytest.mark.asyncio
3437
async def test_string_prompt():

0 commit comments

Comments (0)