|
22 | 22 | logger = logging.getLogger(__name__) |
23 | 23 |
|
24 | 24 |
|
| 25 | +def patch_save_json(data, filename): |
| 26 | + import json |
| 27 | + import os |
| 28 | + |
| 29 | + path = f"/Users/shahules/Myprojects/ragas/experiments/{filename}.json" |
| 30 | + if os.path.exists(path): |
| 31 | + database = json.load(open(path)) |
| 32 | + database = database if isinstance(database, list) else [database] |
| 33 | + database.append(data) |
| 34 | + else: |
| 35 | + database = [data] |
| 36 | + with open(path, "w") as f: |
| 37 | + json.dump(database, f, indent=4) |
| 38 | + |
| 39 | + |
25 | 40 | @dataclass |
26 | 41 | class Filter(ABC): |
27 | 42 | llm: BaseRagasLLM |
@@ -56,6 +71,8 @@ async def filter(self, node: Node) -> t.Dict: |
56 | 71 | score = await json_loader.safe_load(output, llm=self.llm) |
57 | 72 | score = score if isinstance(score, dict) else {} |
58 | 73 | logger.debug("node filter: %s", score) |
| 74 | + score.update({"context": node.page_content}) |
| 75 | + patch_save_json(score, "node_filter") |
59 | 76 | score.update({"score": score.get("score", 0) >= self.threshold}) |
60 | 77 | return score |
61 | 78 |
|
@@ -87,6 +104,8 @@ async def filter(self, question: str) -> t.Tuple[bool, str]: |
87 | 104 | results = results.generations[0][0].text.strip() |
88 | 105 | json_results = await json_loader.safe_load(results, llm=self.llm) |
89 | 106 | json_results = json_results if isinstance(json_results, dict) else {} |
| 107 | + json_results.update({"question": question}) |
| 108 | + patch_save_json(json_results, "question_filter") |
90 | 109 | logger.debug("filtered question: %s", json_results) |
91 | 110 | return json_results.get("verdict") == "1", json_results.get("feedback", "") |
92 | 111 |
|
@@ -120,6 +139,8 @@ async def filter(self, simple_question: str, compressed_question: str) -> bool: |
120 | 139 | results = results.generations[0][0].text.strip() |
121 | 140 | json_results = await json_loader.safe_load(results, llm=self.llm) |
122 | 141 | json_results = json_results if isinstance(json_results, dict) else {} |
| 142 | + json_results.update({"questions": [simple_question, compressed_question]}) |
| 143 | + patch_save_json(json_results, "evolution_filter") |
123 | 144 | logger.debug("evolution filter: %s", json_results) |
124 | 145 | return json_results.get("verdict") == "1" |
125 | 146 |
|
|
0 commit comments