Skip to content

Commit dc1101d

Browse files
committed
Use json-repair package to fix LLM generated json
1 parent bc8540e commit dc1101d

File tree

9 files changed

+1504
-1663
lines changed

9 files changed

+1504
-1663
lines changed

docs/source/api.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,10 @@ PipelineStatusUpdateError
489489

490490
.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError
491491
:show-inheritance:
492+
493+
494+
JSONRepairError
495+
===============
496+
497+
.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.JSONRepairError
498+
:show-inheritance:
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""This example illustrates how to get started easily with the SimpleKGPipeline
2+
and ingest PDF into a Neo4j Knowledge Graph.
3+
4+
This example assumes a Neo4j db is up and running. Update the credentials below
5+
if needed.
6+
7+
It's assumed Ollama is used to run a model locally.
8+
"""
9+
10+
import asyncio
11+
import ollama
12+
from pathlib import Path
13+
14+
import neo4j
15+
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
16+
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult
17+
from neo4j_graphrag.llm import LLMInterface, LLMResponse
18+
19+
from llama_index.embeddings.ollama import OllamaEmbedding
20+
from neo4j_graphrag.embeddings.base import Embedder
21+
22+
23+
class OllamaEmbedder(Embedder):
    """Adapter exposing a llama-index ``OllamaEmbedding`` through the
    neo4j-graphrag ``Embedder`` interface."""

    def __init__(self, ollama_embedding: OllamaEmbedding) -> None:
        self.embedder = ollama_embedding

    def embed_query(self, text: str) -> list[float]:
        """Embed a single text and return its vector."""
        # The underlying API is batch-oriented: send a one-item batch and
        # unwrap the single resulting embedding.
        batch: list[list[float]] = self.embedder.get_text_embedding_batch(
            [text], show_progress=True
        )
        return batch[0]
32+
33+
34+
# Embedder backed by a locally running Ollama model (served by llama-index).
ollama_embedding = OllamaEmbedding(
    model_name="qwen2",
    base_url="http://localhost:11434",
    ollama_additional_kwargs={"mirostat": 0},
)
embedder = OllamaEmbedder(ollama_embedding)

# Neo4j db infos
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"

# NOTE(review): root_dir is currently unused — file_path is a plain
# repo-relative string; confirm whether it should be joined with root_dir.
root_dir = Path(__file__).parents[4]
file_path = "examples/data/Harry Potter and the Chamber of Secrets Summary.pdf"

# Instantiate Entity and Relation objects. This defines the
# entities and relations the LLM will be looking for in the text.
ENTITIES = ["Person", "Organization", "Location"]
RELATIONS = ["SITUATED_AT", "INTERACTS", "LED_BY"]
POTENTIAL_SCHEMA = [
    ("Person", "SITUATED_AT", "Location"),
    ("Person", "INTERACTS", "Person"),
    ("Organization", "LED_BY", "Person"),
]
60+
61+
62+
async def define_and_run_pipeline(
    neo4j_driver: neo4j.Driver,
    llm: LLMInterface,
) -> PipelineResult:
    """Build a SimpleKGPipeline from the module-level schema and run it
    asynchronously on the example PDF, returning the pipeline result."""
    pipeline = SimpleKGPipeline(
        llm=llm,
        driver=neo4j_driver,
        embedder=embedder,
        entities=ENTITIES,
        relations=RELATIONS,
        potential_schema=POTENTIAL_SCHEMA,
    )
    return await pipeline.run_async(file_path=str(file_path))
76+
77+
78+
async def main() -> PipelineResult:
    """Run the KG-building pipeline against a local Ollama LLM."""

    class OllamaLLM(LLMInterface):
        """Minimal LLMInterface adapter over the ollama chat API."""

        def invoke(self, input: str) -> LLMResponse:
            # Temperature 0 for deterministic, extraction-friendly output.
            chat_result = ollama.chat(
                model=self.model_name,
                messages=[{"role": "user", "content": input}],
                options={"temperature": 0.0},
            )
            return LLMResponse(content=chat_result["message"]["content"])

        async def ainvoke(self, input: str) -> LLMResponse:
            return self.invoke(input)  # TODO: implement async with ollama.AsyncClient

    llm = OllamaLLM("llama3.1")
    with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
        res = await define_and_run_pipeline(driver, llm)
    return res
101+
102+
103+
if __name__ == "__main__":
    # Drive the async pipeline from the synchronous script entry point.
    pipeline_result = asyncio.run(main())
    print(pipeline_result)

poetry.lock

Lines changed: 1323 additions & 1609 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ google-cloud-aiplatform = {version = "^1.66.0", optional = true }
4444
cohere = {version = "^5.9.0", optional = true}
4545
mistralai = {version = "^1.0.3", optional = true}
4646
qdrant-client = {version = "^1.11.3", optional = true}
47-
llama-index = {version = "^0.10.55", optional = true }
4847
openai = {version = "^1.51.1", optional = true }
4948
anthropic = { version = "^0.36.0", optional = true}
5049
sentence-transformers = {version = "^3.0.0", optional = true }
50+
ollama = "^0"
51+
setuptools = "^75.6.0"
52+
llama-index-embeddings-ollama = "^0.4.0"
53+
json-repair = "^0.30.2"
5154

5255
[tool.poetry.group.dev.dependencies]
5356
urllib3 = "<2"

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
import enum
2020
import json
2121
import logging
22-
import re
2322
from datetime import datetime
2423
from typing import Any, List, Optional, Union
2524

25+
import json_repair
26+
2627
from pydantic import ValidationError, validate_call
2728

2829
from neo4j_graphrag.exceptions import LLMGenerationError
@@ -36,6 +37,7 @@
3637
TextChunks,
3738
)
3839
from neo4j_graphrag.experimental.pipeline.component import Component
40+
from neo4j_graphrag.experimental.pipeline.exceptions import JSONRepairError
3941
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
4042
from neo4j_graphrag.llm import LLMInterface
4143

@@ -100,28 +102,19 @@ def balance_curly_braces(json_string: str) -> str:
100102
return "".join(fixed_json)
101103

102104

103-
def fix_invalid_json(invalid_json_string: str) -> str:
104-
# Fix missing quotes around field names
105-
invalid_json_string = re.sub(
106-
r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string
107-
)
108-
109-
# Fix missing quotes around string values, correctly ignoring null, true, false, and numeric values
110-
invalid_json_string = re.sub(
111-
r"(?<=:\s)(?!(null|true|false|\d+\.?\d*))([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])",
112-
r'"\2"',
113-
invalid_json_string,
114-
)
115-
116-
# Correct the specific issue: remove trailing commas within arrays or objects before closing braces or brackets
117-
invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string)
105+
def fix_invalid_json(raw_json: str) -> str:
    """Repair a potentially malformed LLM-generated JSON string.

    Delegates to the ``json-repair`` package, then normalizes the result.

    Args:
        raw_json: The raw (possibly invalid) JSON text from the LLM.

    Returns:
        The repaired JSON string, stripped of surrounding whitespace.

    Raises:
        JSONRepairError: If repair yields an empty string or the empty
            JSON string literal ``""`` (i.e. nothing usable was recovered).
    """
    repaired_json = json_repair.repair_json(raw_json)

    if isinstance(repaired_json, str):
        repaired_json = repaired_json.strip()
    else:
        # Defensive: treat any non-string return as unrecoverable.
        repaired_json = ""

    # repaired_json is already stripped here, so no further .strip() needed.
    if repaired_json == '""':
        raise JSONRepairError("JSON repair resulted in an empty or invalid JSON.")
    if not repaired_json:
        raise JSONRepairError("JSON repair resulted in an empty string.")
    return repaired_json
125118

126119

127120
class EntityRelationExtractor(Component, abc.ABC):
@@ -223,24 +216,18 @@ async def extract_for_chunk(
223216
)
224217
llm_result = await self.llm.ainvoke(prompt)
225218
try:
226-
result = json.loads(llm_result.content)
227-
except json.JSONDecodeError:
228-
logger.info(
229-
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}. Trying to fix it."
230-
)
231-
fixed_content = fix_invalid_json(llm_result.content)
232-
try:
233-
result = json.loads(fixed_content)
234-
except json.JSONDecodeError as e:
235-
if self.on_error == OnError.RAISE:
236-
raise LLMGenerationError(
237-
f"LLM response is not valid JSON {fixed_content}: {e}"
238-
)
239-
else:
240-
logger.error(
241-
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
242-
)
243-
result = {"nodes": [], "relationships": []}
219+
llm_generated_json = fix_invalid_json(llm_result.content)
220+
result = json.loads(llm_generated_json)
221+
except (json.JSONDecodeError, JSONRepairError) as e:
222+
if self.on_error == OnError.RAISE:
223+
raise LLMGenerationError(
224+
f"LLM response is not valid JSON {llm_result.content}: {e}"
225+
)
226+
else:
227+
logger.error(
228+
f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}"
229+
)
230+
result = {"nodes": [], "relationships": []}
244231
try:
245232
chunk_graph = Neo4jGraph(**result)
246233
except ValidationError as e:

src/neo4j_graphrag/experimental/pipeline/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ class PipelineStatusUpdateError(Neo4jGraphRagError):
3131
"""Raises when trying an invalid change of state (e.g. DONE => DOING)"""
3232

3333
pass
34+
35+
36+
class JSONRepairError(Neo4jGraphRagError):
    """Raised when JSON repair fails to produce valid JSON."""

src/neo4j_graphrag/generation/prompts.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ class ERExtractionTemplate(PromptTemplate):
174174
Do respect the source and target node types for relationship and
175175
the relationship direction.
176176
177-
Do not return any additional information other than the JSON in it.
177+
Make sure you adhere to the following rules to produce valid JSON objects:
178+
- Do not return any additional information other than the JSON in it.
179+
- Omit any backticks around the JSON - simply output the JSON on its own.
180+
- The JSON object must not wrapped into a list - it is its own JSON object.
181+
- Property names must be enclosed in double quotes
178182
179183
Examples:
180184
{examples}

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ def invoke(self, input: str) -> LLMResponse:
8484
messages=self.get_messages(input),
8585
**self.model_params,
8686
)
87-
if response is None or response.choices is None or not response.choices:
88-
content = ""
89-
else:
90-
content = response.choices[0].message.content or ""
87+
content: str = ""
88+
if response and response.choices:
89+
possible_content = response.choices[0].message.content
90+
if isinstance(possible_content, str):
91+
content = possible_content
9192
return LLMResponse(content=content)
9293
except SDKError as e:
9394
raise LLMGenerationError(e)
@@ -111,10 +112,11 @@ async def ainvoke(self, input: str) -> LLMResponse:
111112
messages=self.get_messages(input),
112113
**self.model_params,
113114
)
114-
if response is None or response.choices is None or not response.choices:
115-
content = ""
116-
else:
117-
content = response.choices[0].message.content or ""
115+
content: str = ""
116+
if response and response.choices:
117+
possible_content = response.choices[0].message.content
118+
if isinstance(possible_content, str):
119+
content = possible_content
118120
return LLMResponse(content=content)
119121
except SDKError as e:
120122
raise LLMGenerationError(e)

tests/unit/experimental/components/test_entity_relation_extractor.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import json
18-
from unittest.mock import MagicMock
18+
from unittest.mock import MagicMock, patch
1919

2020
import pytest
2121
from neo4j_graphrag.exceptions import LLMGenerationError
@@ -31,6 +31,7 @@
3131
TextChunk,
3232
TextChunks,
3333
)
34+
from neo4j_graphrag.experimental.pipeline.exceptions import JSONRepairError
3435
from neo4j_graphrag.llm import LLMInterface, LLMResponse
3536

3637

@@ -154,8 +155,8 @@ async def test_extractor_llm_badly_formatted_json() -> None:
154155
llm=llm,
155156
)
156157
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
157-
with pytest.raises(LLMGenerationError):
158-
await extractor.run(chunks=chunks)
158+
159+
await extractor.run(chunks=chunks)
159160

160161

161162
@pytest.mark.asyncio
@@ -177,7 +178,7 @@ async def test_extractor_llm_invalid_json() -> None:
177178

178179

179180
@pytest.mark.asyncio
180-
async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None:
181+
async def test_extractor_llm_badly_formatted_json_gets_fixed() -> None:
181182
llm = MagicMock(spec=LLMInterface)
182183
llm.ainvoke.return_value = LLMResponse(
183184
content='{"nodes": [{"id": "0", "label": "Person", "properties": {}}], "relationships": [}'
@@ -190,7 +191,11 @@ async def test_extractor_llm_badly_formatted_json_do_not_raise() -> None:
190191
)
191192
chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)])
192193
res = await extractor.run(chunks=chunks)
193-
assert res.nodes == []
194+
print("res.nodes", res.nodes)
195+
assert len(res.nodes) == 1
196+
assert res.nodes[0].label == "Person"
197+
assert res.nodes[0].properties == {"chunk_index": 0}
198+
assert res.nodes[0].embedding_properties is None
194199
assert res.relationships == []
195200

196201

@@ -205,6 +210,14 @@ async def test_extractor_custom_prompt() -> None:
205210
llm.ainvoke.assert_called_once_with("this is my prompt")
206211

207212

213+
def test_fix_invalid_json_empty_result() -> None:
214+
json_string = "invalid json"
215+
216+
with patch("json_repair.repair_json", return_value=""):
217+
with pytest.raises(JSONRepairError):
218+
fix_invalid_json(json_string)
219+
220+
208221
def test_fix_unquoted_keys() -> None:
209222
json_string = '{name: "John", age: "30"}'
210223
expected_result = '{"name": "John", "age": "30"}'

0 commit comments

Comments
 (0)