Skip to content

Commit b4942c1

Browse files
shahules786jjmachanayulockin
authored
fix: import train config and add tests (#1776)
The cache feature causes import of config to fail because of the pydantic schema not found error. Added fix and a test to prevent this in future. --------- Co-authored-by: Jithin James <[email protected]> Co-authored-by: Ayush Thakur <[email protected]>
1 parent 5d99909 commit b4942c1

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

src/ragas/cache.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from abc import ABC, abstractmethod
66
from typing import Any, Optional
77

8-
from pydantic import BaseModel
8+
from pydantic import BaseModel, GetCoreSchemaHandler
9+
from pydantic_core import CoreSchema, core_schema
910

1011

1112
class CacheInterface(ABC):
@@ -21,6 +22,17 @@ def set(self, key: str, value) -> None:
2122
def has_key(self, key: str) -> bool:
2223
pass
2324

25+
@classmethod
26+
def __get_pydantic_core_schema__(
27+
cls, source_type: Any, handler: GetCoreSchemaHandler
28+
) -> CoreSchema:
29+
"""
30+
Define how Pydantic generates a schema for BaseRagasEmbeddings.
31+
"""
32+
return core_schema.no_info_after_validator_function(
33+
cls, core_schema.is_instance_schema(cls) # The validator function
34+
)
35+
2436

2537
class DiskCacheBackend(CacheInterface):
2638
def __init__(self, cache_dir: str = ".cache"):

tests/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import typing as t
44

5+
import numpy as np
56
import pytest
67
from langchain_core.outputs import Generation, LLMResult
78

9+
from ragas.embeddings.base import BaseRagasEmbeddings
810
from ragas.llms.base import BaseRagasLLM
911

1012
if t.TYPE_CHECKING:
@@ -46,6 +48,26 @@ async def agenerate_text( # type: ignore
4648
return LLMResult(generations=[[Generation(text=prompt.to_string())]])
4749

4850

51+
class EchoEmbedding(BaseRagasEmbeddings):
52+
53+
async def aembed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
54+
return [np.random.rand(768).tolist() for _ in texts]
55+
56+
async def aembed_query(self, text: str) -> t.List[float]:
57+
return [np.random.rand(768).tolist()]
58+
59+
def embed_documents(self, texts: t.List[str]) -> t.List[t.List[float]]:
60+
return [np.random.rand(768).tolist() for _ in texts]
61+
62+
def embed_query(self, text: str) -> t.List[float]:
63+
return [np.random.rand(768).tolist()]
64+
65+
4966
@pytest.fixture
5067
def fake_llm():
5168
return EchoLLM()
69+
70+
71+
@pytest.fixture
72+
def fake_embedding():
73+
return EchoEmbedding()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
def test_load_config(fake_llm, fake_embedding):
2+
3+
from ragas.config import DemonstrationConfig, InstructionConfig
4+
5+
inst_config = InstructionConfig(llm=fake_llm)
6+
demo_config = DemonstrationConfig(embedding=fake_embedding)
7+
assert inst_config.llm == fake_llm
8+
assert demo_config.embedding == fake_embedding

0 commit comments

Comments
 (0)