Skip to content

Commit b988f53

Browse files
authored
fix: added a test for prompts (#1197)
1 parent a77c475 commit b988f53

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

src/ragas/llms/prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def get_all_keys(nested_json):
269269

270270
return self
271271

272-
def save(self, cache_dir: t.Optional[str] = None) -> None:
272+
def save(self, cache_dir: t.Optional[str] = None):
273273
cache_dir = cache_dir if cache_dir else get_cache_dir()
274274
cache_dir = os.path.join(cache_dir, self.language)
275275
if not os.path.exists(cache_dir):

tests/unit/llms/test_prompt.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,12 @@ def test_prompt_object_names():
121121
obj.name not in prompt_object_names
122122
), f"Duplicate prompt name: {obj.name}"
123123
prompt_object_names.append(obj.name)
124+
125+
126+
def test_save_and_load(tmp_path):
127+
for testcase in TESTCASES:
128+
prompt = Prompt(**testcase)
129+
prompt.save(tmp_path)
130+
loaded_prompt = prompt._load(prompt.language, prompt.name, tmp_path)
131+
132+
assert prompt == loaded_prompt

0 commit comments

Comments
 (0)