22
33import json
44import os
5+ import typing as t
56import warnings
7+ from dataclasses import dataclass
68from functools import lru_cache
79
10+ from langchain .callbacks .manager import CallbackManager , trace_as_chain_group
11+ from langchain .prompts import ChatPromptTemplate , HumanMessagePromptTemplate
12+
13+ if t .TYPE_CHECKING :
14+ from ragas .llms import RagasLLM
15+
# Name of the environment variable associated with debug mode.
# NOTE(review): presumably read elsewhere to toggle debugging — confirm
# against the code that consumes it.
DEBUG_ENV_VAR = "RAGAS_DEBUG"
# Sentinel value to tell us that there is no key passed to the llm/embeddings.
NO_KEY = "no-key"
@@ -29,3 +37,119 @@ def load_as_json(text):
2937 warnings .warn (f"Invalid json: { e } " )
3038
3139 return {}
40+
41+
# Few-shot prompt asking a chat model to rewrite arbitrary text as valid JSON.
# It carries two worked examples followed by the slot for the real input:
#   1. an object with missing commas and a trailing comma, repaired;
#   2. a string value containing unescaped double quotes, rewritten with
#      single quotes.
# Braces belonging to the example JSON are doubled so that only `{input}` is
# treated as a template variable (it is filled via
# `JSON_PROMPT.format(input=...)`).
JSON_PROMPT = HumanMessagePromptTemplate.from_template(
    """

Rewrite the input into valid json


Input:
{{
"name": "John Doe",
"age": 30,
"isStudent": false
"address": {{
"street": "123 Main St",
"city": "Anytown",
"state": "CA",
}}
"hobbies": ["reading", "swimming", "cycling"]
}}
Output:
{{
"name": "John Doe",
"age": 30,
"isStudent": false,
"address": {{
"street": "123 Main St",
"city": "Anytown",
"state": "CA"
}},
"hobbies": ["reading", "swimming", "cycling"]
}}


Input:
{{
"statement": "The Earth is also known as "Terra" "
}}
Output:
{{
"statement": "The Earth is also known as 'Terra'"
}}

Input:
{input}

Output:
"""
)
89+
90+
@dataclass
class JsonLoader:
    """Extract and parse JSON embedded in free-form text.

    The outermost ``{...}`` / ``[...]`` span is located and parsed with
    ``json.loads``. When parsing fails, the text is sent to an LLM to be
    rewritten as valid JSON and parsing is attempted again.
    """

    # Maximum number of LLM repair round-trips after the initial parse attempt.
    max_retries: int = 2

    def safe_load(self, text: str, llm: "RagasLLM"):
        """Parse the first JSON value found in ``text``.

        Args:
            text: Raw text that should contain a JSON object or array.
            llm: LLM used to rewrite ``text`` whenever it cannot be parsed.

        Returns:
            The decoded JSON value, or ``{}`` if parsing still fails after
            ``max_retries`` repair attempts.
        """
        for attempt in range(self.max_retries + 1):
            try:
                start, end = self._find_outermost_json(text)
                # (-1, -1) yields an empty slice, which json.loads rejects
                # with JSONDecodeError (a ValueError), triggering a repair.
                return json.loads(text[start:end])
            except ValueError:
                # Only spend an LLM call when another parse attempt will
                # actually consume the repaired text.
                if attempt < self.max_retries:
                    text = self._fix_to_json(text, llm)

        return {}

    def _fix_to_json(
        self,
        text,
        llm,
        callbacks: "t.Optional[CallbackManager]" = None,
        callback_group_name: str = "batch",
    ):
        """Ask ``llm`` to rewrite ``text`` as valid JSON.

        Args:
            text: The malformed text to repair.
            llm: Object exposing ``generate(prompts, n=..., callbacks=...)``.
            callbacks: Optional callback manager used for tracing.
            callback_group_name: Name of the trace group wrapping the call.

        Returns:
            The text of the first generation returned by the LLM.
        """
        # TODO (executor)
        with trace_as_chain_group(
            callback_group_name, callback_manager=callbacks
        ) as batch_group:
            human_prompt = ChatPromptTemplate.from_messages(
                [JSON_PROMPT.format(input=text)]
            )
            results = llm.generate(
                [human_prompt],
                n=1,
                callbacks=batch_group,
            )
        return results.generations[0][0].text

    def _find_outermost_json(self, text):
        """Locate the first balanced top-level JSON object/array in ``text``.

        Braces and brackets that appear inside double-quoted strings of the
        candidate JSON are ignored, so values such as ``"}"`` do not end the
        scan early. Quotes appearing before the first opening brace/bracket
        are not treated as string delimiters.

        Returns:
            ``(start, end)`` such that ``text[start:end]`` is the candidate
            JSON, or ``(-1, -1)`` when no balanced span exists.
        """
        openers = {"}": "{", "]": "["}
        stack = []
        start_index = -1
        in_string = False
        escaped = False

        for i, char in enumerate(text):
            if in_string:
                if escaped:
                    # Character after a backslash is consumed verbatim.
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
                continue

            if char == '"':
                # Track strings only once inside a JSON candidate; stray
                # quotes in surrounding prose must not hide the opener.
                if stack:
                    in_string = True
            elif char in "{[":
                if not stack:
                    start_index = i
                stack.append(char)
            elif char in openers:
                if stack and stack.pop() != openers[char]:
                    # Mismatched closing brace/bracket, invalid JSON
                    break

                if not stack and start_index != -1:
                    # Found a valid outermost JSON; add 1 to include the
                    # closing brace/bracket in the range.
                    return start_index, i + 1

        return -1, -1  # No valid JSON found
154+
# Shared module-level loader instance built with the default settings
# (max_retries=2).
json_loader = JsonLoader()
0 commit comments