Skip to content

Commit bf2b17a

Browse files
authored
Merge pull request #35 from omkarjoshi0304/Dev
feat: Add JSON output support to query_rag.py script
2 parents 2cf406f + f57af14 commit bf2b17a

File tree

1 file changed

+124
-14
lines changed

1 file changed

+124
-14
lines changed

scripts/query_rag.py

Lines changed: 124 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import argparse
44
import importlib
5+
import json
6+
import logging
57
import os
68
import sys
79
import tempfile
@@ -30,23 +32,76 @@ def _llama_index_query(args: argparse.Namespace) -> None:
3032
storage_context=storage_context,
3133
index_id=args.product_index,
3234
)
35+
3336
if args.node is not None:
34-
print(storage_context.docstore.get_node(args.node))
37+
node = storage_context.docstore.get_node(args.node)
38+
result = {
39+
"query": args.query,
40+
"type": "single_node",
41+
"node_id": args.node,
42+
"node": {
43+
"id": node.node_id,
44+
"text": node.text,
45+
"metadata": node.metadata if hasattr(node, 'metadata') else {}
46+
}
47+
}
48+
if args.json:
49+
print(json.dumps(result, indent=2))
50+
else:
51+
print(node)
3552
else:
3653
retriever = vector_index.as_retriever(similarity_top_k=args.top_k)
3754
nodes = retriever.retrieve(args.query)
55+
3856
if len(nodes) == 0:
39-
print(f"No nodes retrieved for query: {args.query}")
57+
logging.warning(f"No nodes retrieved for query: {args.query}")
58+
if args.json:
59+
result = {
60+
"query": args.query,
61+
"top_k": args.top_k,
62+
"threshold": args.threshold,
63+
"nodes": []
64+
}
65+
print(json.dumps(result, indent=2))
4066
exit(1)
67+
4168
if args.threshold > 0.0 and nodes[0].score < args.threshold:
42-
print(
69+
logging.warning(
4370
f"Score {nodes[0].score} of the top retrieved node for query '{args.query}' "
4471
f"didn't cross the minimal threshold {args.threshold}."
4572
)
73+
if args.json:
74+
result = {
75+
"query": args.query,
76+
"top_k": args.top_k,
77+
"threshold": args.threshold,
78+
"nodes": []
79+
}
80+
print(json.dumps(result, indent=2))
4681
exit(1)
47-
for n in nodes:
48-
print("=" * 80)
49-
print(n)
82+
83+
# Format results
84+
result = {
85+
"query": args.query,
86+
"top_k": args.top_k,
87+
"threshold": args.threshold,
88+
"nodes": []
89+
}
90+
for node in nodes:
91+
node_data = {
92+
"id": node.node_id,
93+
"score": node.score,
94+
"text": node.text,
95+
"metadata": node.metadata if hasattr(node, 'metadata') else {}
96+
}
97+
result["nodes"].append(node_data)
98+
99+
if args.json:
100+
print(json.dumps(result, indent=2))
101+
else:
102+
for n in nodes:
103+
print("=" * 80)
104+
print(n)
50105

51106

52107
def _get_db_path_dict(vector_type: str, config: dict[str, Any]) -> dict[str, Any]:
@@ -108,20 +163,56 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
108163

109164
md = res.metadata
110165
if len(md["chunks"]) == 0:
111-
print(f"No chunks retrieved for query: {args.query}")
166+
logging.warning(f"No chunks retrieved for query: {args.query}")
167+
if args.json:
168+
result = {
169+
"query": args.query,
170+
"top_k": args.top_k,
171+
"threshold": args.threshold,
172+
"nodes": []
173+
}
174+
print(json.dumps(result, indent=2))
112175
exit(1)
176+
113177
threshold = args.threshold
114178
if threshold > 0.0 and md.get("scores") and md["scores"][0].score < threshold:
115-
print(
179+
logging.warning(
116180
f"Score {md['scores'][0].score} of the top retrieved node for query '{args.query}' "
117181
f"didn't cross the minimal threshold {threshold}."
118182
)
183+
if args.json:
184+
result = {
185+
"query": args.query,
186+
"top_k": args.top_k,
187+
"threshold": args.threshold,
188+
"nodes": []
189+
}
190+
print(json.dumps(result, indent=2))
119191
exit(1)
120192

121-
# Method 1 to present data:
193+
# Format results
194+
result = {
195+
"query": args.query,
196+
"top_k": args.top_k,
197+
"threshold": args.threshold,
198+
"nodes": []
199+
}
200+
122201
for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]):
123-
print("=" * 80)
124-
print(f"Node ID: {_id}\nScore: {score}\nText:\n{chunk}")
202+
node_data = {
203+
"id": _id,
204+
"score": score.score if hasattr(score, 'score') else score,
205+
"text": chunk,
206+
"metadata": {}
207+
}
208+
result["nodes"].append(node_data)
209+
210+
if args.json:
211+
print(json.dumps(result, indent=2))
212+
else:
213+
for _id, chunk, score in zip(md["document_ids"], md["chunks"], md["scores"]):
214+
print("=" * 80)
215+
print(f"Node ID: {_id}\nScore: {score}\nText:\n{chunk}")
125216

126217
# Method 2 to present data:
127218
# for content in res.content:
@@ -130,7 +221,6 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
130221
# else:
131222
# print(content)
132223

133-
134224
if __name__ == "__main__":
135225
parser = argparse.ArgumentParser(
136226
description="Utility script for querying RAG database"
@@ -161,10 +251,30 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
161251
choices=["auto", "faiss", "llamastack-faiss", "llamastack-sqlite-vec"],
162252
help="vector store type to be used.",
163253
)
254+
parser.add_argument(
255+
"--json",
256+
action="store_true",
257+
help="Output results in JSON format",
258+
)
164259

165260
args = parser.parse_args()
166261

167-
print("Command line used: " + " ".join(sys.argv))
262+
if args.json:
263+
# In JSON mode, only show ERROR or higher to avoid polluting JSON output
264+
logging.basicConfig(
265+
level=logging.ERROR,
266+
format='%(levelname)s: %(message)s',
267+
stream=sys.stderr # Send logs to stderr to keep stdout clean for JSON
268+
)
269+
else:
270+
# In normal mode, show info and above
271+
logging.basicConfig(
272+
level=logging.INFO,
273+
format='%(message)s'
274+
)
275+
276+
if not args.json:
277+
logging.info("Command line used: " + " ".join(sys.argv))
168278

169279
vector_store_type = args.vector_store_type
170280
if args.vector_store_type == "auto":
@@ -175,7 +285,7 @@ def _llama_stack_query(args: argparse.Namespace) -> None:
175285
elif os.path.exists(os.path.join(args.db_path, "faiss_store.db")):
176286
args.vector_store_type = "llamastack-faiss"
177287
else:
178-
print("Cannot recognize the DB in", args.db_path)
288+
logging.error(f"Cannot recognize the DB in {args.db_path}")
179289
exit(1)
180290

181291
if args.vector_store_type == "faiss":

0 commit comments

Comments (0)