
Commit 5ff2d3c

darthtrevino, AlonsoGuevara, and jgbradley1 authored
Remove graphrag.llm, replace with fnllm (#1315)
* add fnllm; remove llm folder
* remove llm unit tests
* update imports
* update imports
* formatting
* enable autosave
* update mockllm
* update community reports extractor
* move most llm usage to fnllm
* update type issues
* fix unit tests
* type updates
* update dictionary
* semver
* update llm construction, get integration tests working
* load from llmparameters model
* move ruff settings to ruff.toml
* add gitattributes file
* ignore ruff.toml spelling
* update .gitattributes
* update gitignore
* update config construction
* update prompt var usage
* add cache adapter
* use cache adapter in embeddings calls
* update embedding strategy
* add fnllm
* add pytest-dotenv
* fix some verb tests
* get verb tests running
* update ruff.toml for vscode
* enable ruff native server in vscode
* update artifact inspecting code
* remove local-test update
* use string.replace instead of string.format in community reports extractor
* bump timeout
* revert ruff.toml, vscode settings for another pr
* revert cspell config
* revert gitignore
* remove json-repair, update fnllm
* use fnllm generic type interfaces
* update load_llm to use target models
* consolidate chat parameters
* add 'extra_attributes' prop to community report response
* formatting
* update fnllm
* formatting
* formatting
* Add defaults to some llm params to avoid null on params hash
* Formatting

---------

Co-authored-by: Alonso Guevara <[email protected]>
Co-authored-by: Josh Bradley <[email protected]>
1 parent d43124e commit 5ff2d3c

File tree: 77 files changed (+670, -2747 lines)

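The core of the change is visible in the extractor diffs below: the internal graphrag.llm CompletionLLM protocol is replaced by fnllm's ChatLLM, and the completion text moves from response.output to response.output.content. A minimal sketch of the new call shape, with the call signature and response fields assumed only from how the diffs use them:

    from fnllm import ChatLLM


    async def run_prompt(llm: ChatLLM, prompt: str) -> str:
        # The extractors below pre-format the prompt themselves and pass an
        # optional call name; the reply text now lives on output.content.
        response = await llm(prompt, name="example-call")
        return response.output.content or ""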
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "replace llm package with fnllm"
+}

dictionary.txt

Lines changed: 1 addition & 0 deletions
@@ -68,6 +68,7 @@ pypi
 nbformat
 semversioner
 mkdocs
+fnllm
 typer

 # Library Methods

graphrag/api/prompt_tune.py

Lines changed: 3 additions & 4 deletions
@@ -97,10 +97,9 @@ async def generate_indexing_prompts(
     # Create LLM from config
     llm = load_llm(
         "prompt_tuning",
-        config.llm.type,
-        NoopVerbCallbacks(),
-        None,
-        config.llm.model_dump(),
+        config.llm,
+        cache=None,
+        callbacks=NoopVerbCallbacks(),
     )

     if not domain:
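For reference, a sketch of the new load_llm call shape used above, assuming the import paths (the diff only shows the call site) and using illustrative parameter values:

    from datashaper import NoopVerbCallbacks
    from graphrag.config.models.llm_parameters import LLMParameters
    from graphrag.index.llm import load_llm  # import path assumed; not shown in this diff

    # The whole LLMParameters model is now passed instead of
    # (type, callbacks, cache, model_dump()); cache and callbacks are keyword arguments.
    llm_config = LLMParameters(type="openai_chat", model="gpt-4-turbo-preview", api_key="...")
    llm = load_llm(
        "prompt_tuning",
        llm_config,
        cache=None,
        callbacks=NoopVerbCallbacks(),
    )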

graphrag/config/create_graphrag_config.py

Lines changed: 1 addition & 3 deletions
@@ -702,9 +702,7 @@ class Section(str, Enum):

 def _is_azure(llm_type: LLMType | None) -> bool:
     return (
-        llm_type == LLMType.AzureOpenAI
-        or llm_type == LLMType.AzureOpenAIChat
-        or llm_type == LLMType.AzureOpenAIEmbedding
+        llm_type == LLMType.AzureOpenAIChat or llm_type == LLMType.AzureOpenAIEmbedding
     )


graphrag/config/defaults.py

Lines changed: 3 additions & 0 deletions
@@ -20,9 +20,11 @@

 ASYNC_MODE = AsyncType.Threaded
 ENCODING_MODEL = "cl100k_base"
+AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
 #
 # LLM Parameters
 #
+LLM_FREQUENCY_PENALTY = 0.0
 LLM_TYPE = LLMType.OpenAIChat
 LLM_MODEL = "gpt-4-turbo-preview"
 LLM_MAX_TOKENS = 4000

@@ -34,6 +36,7 @@
 LLM_REQUESTS_PER_MINUTE = 0
 LLM_MAX_RETRIES = 10
 LLM_MAX_RETRY_WAIT = 10.0
+LLM_PRESENCE_PENALTY = 0.0
 LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION = True
 LLM_CONCURRENT_REQUESTS = 25

graphrag/config/enums.py

Lines changed: 0 additions & 4 deletions
@@ -100,10 +100,6 @@ class LLMType(str, Enum):
     OpenAIEmbedding = "openai_embedding"
     AzureOpenAIEmbedding = "azure_openai_embedding"

-    # Raw Completion
-    OpenAI = "openai"
-    AzureOpenAI = "azure_openai"
-
     # Chat Completion
     OpenAIChat = "openai_chat"
     AzureOpenAIChat = "azure_openai_chat"
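Since the raw-completion members are removed, any configuration still using openai or azure_openai as the llm type has to move to the chat variants. A throwaway check of the resulting behavior, assuming the private _is_azure helper from create_graphrag_config.py stays importable:

    from graphrag.config.create_graphrag_config import _is_azure
    from graphrag.config.enums import LLMType

    # Only the chat and embedding Azure variants remain; LLMType.OpenAI and
    # LLMType.AzureOpenAI no longer exist after this commit.
    assert _is_azure(LLMType.AzureOpenAIChat)
    assert _is_azure(LLMType.AzureOpenAIEmbedding)
    assert not _is_azure(LLMType.OpenAIChat)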

graphrag/config/models/llm_parameters.py

Lines changed: 17 additions & 0 deletions
@@ -20,7 +20,13 @@ class LLMParameters(BaseModel):
     type: LLMType = Field(
         description="The type of LLM model to use.", default=defs.LLM_TYPE
     )
+    encoding_model: str | None = Field(
+        description="The encoding model to use", default=defs.ENCODING_MODEL
+    )
     model: str = Field(description="The LLM model to use.", default=defs.LLM_MODEL)
+    embeddings_model: str | None = Field(
+        description="The embeddings model to use.", default=defs.EMBEDDING_MODEL
+    )
     max_tokens: int | None = Field(
         description="The maximum number of tokens to generate.",
         default=defs.LLM_MAX_TOKENS,

@@ -37,6 +43,14 @@ class LLMParameters(BaseModel):
         description="The number of completions to generate.",
         default=defs.LLM_N,
     )
+    frequency_penalty: float | None = Field(
+        description="The frequency penalty to use for token generation.",
+        default=defs.LLM_FREQUENCY_PENALTY,
+    )
+    presence_penalty: float | None = Field(
+        description="The presence penalty to use for token generation.",
+        default=defs.LLM_PRESENCE_PENALTY,
+    )
     request_timeout: float = Field(
         description="The request timeout to use.", default=defs.LLM_REQUEST_TIMEOUT
     )

@@ -86,3 +100,6 @@ class LLMParameters(BaseModel):
         description="Whether to use concurrent requests for the LLM service.",
         default=defs.LLM_CONCURRENT_REQUESTS,
     )
+    responses: list[str | BaseModel] | None = Field(
+        default=None, description="Static responses to use in mock mode."
+    )
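The new fields all carry defaults (see defaults.py above), which is what the "avoid null on params hash" commit note refers to. A small sketch of constructing the model, with an illustrative model name:

    from graphrag.config.models.llm_parameters import LLMParameters

    # Unset fields fall back to the new defaults rather than None.
    params = LLMParameters(model="gpt-4-turbo-preview")
    print(params.frequency_penalty, params.presence_penalty)  # 0.0 0.0
    print(params.encoding_model)                              # "cl100k_base"
    print(params.responses)                                   # None (mock mode only)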

graphrag/index/graph/extractors/claims/claim_extractor.py

Lines changed: 12 additions & 12 deletions
@@ -9,10 +9,10 @@
 from typing import Any

 import tiktoken
+from fnllm import ChatLLM

 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
-from graphrag.llm import CompletionLLM
 from graphrag.prompts.index.claim_extraction import (
     CLAIM_EXTRACTION_PROMPT,
     CONTINUE_PROMPT,

@@ -36,7 +36,7 @@ class ClaimExtractorResult:
 class ClaimExtractor:
     """Claim extractor class definition."""

-    _llm: CompletionLLM
+    _llm: ChatLLM
     _extraction_prompt: str
     _summary_prompt: str
     _output_formatter_prompt: str

@@ -48,10 +48,11 @@ class ClaimExtractor:
     _completion_delimiter_key: str
     _max_gleanings: int
     _on_error: ErrorHandlerFn
+    _loop_args: dict[str, Any]

     def __init__(
         self,
-        llm_invoker: CompletionLLM,
+        llm_invoker: ChatLLM,
         extraction_prompt: str | None = None,
         input_text_key: str | None = None,
         input_entity_spec_key: str | None = None,

@@ -87,9 +88,9 @@ def __init__(

         # Construct the looping arguments
         encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
-        yes = encoding.encode("YES")
-        no = encoding.encode("NO")
-        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
+        yes = f"{encoding.encode('YES')[0]}"
+        no = f"{encoding.encode('NO')[0]}"
+        self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}

     async def __call__(
         self, inputs: dict[str, Any], prompt_variables: dict | None = None

@@ -164,13 +165,12 @@ async def _process_document(
         )

         response = await self._llm(
-            self._extraction_prompt,
-            variables={
+            self._extraction_prompt.format(**{
                 self._input_text_key: doc,
                 **prompt_args,
-            },
+            })
         )
-        results = response.output or ""
+        results = response.output.content or ""
         claims = results.strip().removesuffix(completion_delimiter)

         # Repeat to ensure we maximize entity count

@@ -180,7 +180,7 @@ async def _process_document(
                 name=f"extract-continuation-{i}",
                 history=response.history,
             )
-            extension = response.output or ""
+            extension = response.output.content or ""
             claims += record_delimiter + extension.strip().removesuffix(
                 completion_delimiter
             )

@@ -195,7 +195,7 @@ async def _process_document(
                 history=response.history,
                 model_parameters=self._loop_args,
             )
-            if response.output != "YES":
+            if response.output.content != "YES":
                 break

         return self._parse_claim_tuples(results, prompt_args)
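Note the logit-bias keys are now stringified token ids rather than ints, presumably because the bias map is forwarded straight into a JSON request body where object keys must be strings. A sketch of what the constructor builds:

    import tiktoken

    # Mirrors the constructor change above: bias the model toward answering YES/NO
    # with a single token, keyed by string token ids.
    encoding = tiktoken.get_encoding("cl100k_base")
    yes = f"{encoding.encode('YES')[0]}"
    no = f"{encoding.encode('NO')[0]}"
    loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}
    print(loop_args)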

graphrag/index/graph/extractors/community_reports/community_reports_extractor.py

Lines changed: 44 additions & 45 deletions
@@ -8,26 +8,50 @@
 from dataclasses import dataclass
 from typing import Any

+from fnllm import ChatLLM
+from pydantic import BaseModel, Field
+
 from graphrag.index.typing import ErrorHandlerFn
-from graphrag.index.utils.dicts import dict_has_keys_with_types
-from graphrag.llm import CompletionLLM
 from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT

 log = logging.getLogger(__name__)


+class FindingModel(BaseModel):
+    """A model for the expected LLM response shape."""
+
+    summary: str = Field(description="The summary of the finding.")
+    explanation: str = Field(description="An explanation of the finding.")
+
+
+class CommunityReportResponse(BaseModel):
+    """A model for the expected LLM response shape."""
+
+    title: str = Field(description="The title of the report.")
+    summary: str = Field(description="A summary of the report.")
+    findings: list[FindingModel] = Field(
+        description="A list of findings in the report."
+    )
+    rating: float = Field(description="The rating of the report.")
+    rating_explanation: str = Field(description="An explanation of the rating.")
+
+    extra_attributes: dict[str, Any] = Field(
+        default_factory=dict, description="Extra attributes."
+    )
+
+
 @dataclass
 class CommunityReportsResult:
     """Community reports result class definition."""

     output: str
-    structured_output: dict
+    structured_output: CommunityReportResponse | None


 class CommunityReportsExtractor:
     """Community reports extractor class definition."""

-    _llm: CompletionLLM
+    _llm: ChatLLM
     _input_text_key: str
     _extraction_prompt: str
     _output_formatter_prompt: str

@@ -36,7 +60,7 @@ class CommunityReportsExtractor:

     def __init__(
         self,
-        llm_invoker: CompletionLLM,
+        llm_invoker: ChatLLM,
         input_text_key: str | None = None,
         extraction_prompt: str | None = None,
         on_error: ErrorHandlerFn | None = None,

@@ -53,55 +77,30 @@ async def __call__(self, inputs: dict[str, Any]):
         """Call method definition."""
         output = None
         try:
-            response = (
-                await self._llm(
-                    self._extraction_prompt,
-                    json=True,
-                    name="create_community_report",
-                    variables={self._input_text_key: inputs[self._input_text_key]},
-                    is_response_valid=lambda x: dict_has_keys_with_types(
-                        x,
-                        [
-                            ("title", str),
-                            ("summary", str),
-                            ("findings", list),
-                            ("rating", float),
-                            ("rating_explanation", str),
-                        ],
-                        inplace=True,
-                    ),
-                    model_parameters={"max_tokens": self._max_report_length},
-                )
-                or {}
+            input_text = inputs[self._input_text_key]
+            prompt = self._extraction_prompt.replace(
+                "{" + self._input_text_key + "}", input_text
+            )
+            response = await self._llm(
+                prompt,
+                json=True,
+                name="create_community_report",
+                json_model=CommunityReportResponse,
+                model_parameters={"max_tokens": self._max_report_length},
             )
-            output = response.json or {}
+            output = response.parsed_json
         except Exception as e:
             log.exception("error generating community report")
             self._on_error(e, traceback.format_exc(), None)
-            output = {}

-        text_output = self._get_text_output(output)
+        text_output = self._get_text_output(output) if output else ""
         return CommunityReportsResult(
             structured_output=output,
             output=text_output,
         )

-    def _get_text_output(self, parsed_output: dict) -> str:
-        title = parsed_output.get("title", "Report")
-        summary = parsed_output.get("summary", "")
-        findings = parsed_output.get("findings", [])
-
-        def finding_summary(finding: dict):
-            if isinstance(finding, str):
-                return finding
-            return finding.get("summary")
-
-        def finding_explanation(finding: dict):
-            if isinstance(finding, str):
-                return ""
-            return finding.get("explanation")
-
+    def _get_text_output(self, report: CommunityReportResponse) -> str:
         report_sections = "\n\n".join(
-            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
+            f"## {f.summary}\n\n{f.explanation}" for f in report.findings
         )
-        return f"# {title}\n\n{summary}\n\n{report_sections}"
+        return f"# {report.title}\n\n{report.summary}\n\n{report_sections}"

graphrag/index/graph/extractors/graph/graph_extractor.py

Lines changed: 10 additions & 11 deletions
@@ -12,11 +12,11 @@

 import networkx as nx
 import tiktoken
+from fnllm import ChatLLM

 import graphrag.config.defaults as defs
 from graphrag.index.typing import ErrorHandlerFn
 from graphrag.index.utils.string import clean_str
-from graphrag.llm import CompletionLLM
 from graphrag.prompts.index.entity_extraction import (
     CONTINUE_PROMPT,
     GRAPH_EXTRACTION_PROMPT,

@@ -40,7 +40,7 @@ class GraphExtractionResult:
 class GraphExtractor:
     """Unipartite graph extractor class definition."""

-    _llm: CompletionLLM
+    _llm: ChatLLM
     _join_descriptions: bool
     _tuple_delimiter_key: str
     _record_delimiter_key: str

@@ -57,7 +57,7 @@ class GraphExtractor:

     def __init__(
         self,
-        llm_invoker: CompletionLLM,
+        llm_invoker: ChatLLM,
         tuple_delimiter_key: str | None = None,
         record_delimiter_key: str | None = None,
         input_text_key: str | None = None,

@@ -90,9 +90,9 @@ def __init__(

         # Construct the looping arguments
         encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
-        yes = encoding.encode("YES")
-        no = encoding.encode("NO")
-        self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
+        yes = f"{encoding.encode('YES')[0]}"
+        no = f"{encoding.encode('NO')[0]}"
+        self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1}

     async def __call__(
         self, texts: list[str], prompt_variables: dict[str, Any] | None = None

@@ -151,13 +151,12 @@ async def _process_document(
         self, text: str, prompt_variables: dict[str, str]
     ) -> str:
         response = await self._llm(
-            self._extraction_prompt,
-            variables={
+            self._extraction_prompt.format(**{
                 **prompt_variables,
                 self._input_text_key: text,
-            },
+            }),
         )
-        results = response.output or ""
+        results = response.output.content or ""

         # Repeat to ensure we maximize entity count
         for i in range(self._max_gleanings):

@@ -166,7 +165,7 @@ async def _process_document(
                 name=f"extract-continuation-{i}",
                 history=response.history,
             )
-            results += response.output or ""
+            results += response.output.content or ""

             # if this is the final glean, don't bother updating the continuation flag
             if i >= self._max_gleanings - 1:
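Prompt variables are now substituted by the extractor itself (str.format here, str.replace in the community-report extractor) before the text reaches fnllm, instead of being passed as variables=. A plain-Python sketch with an illustrative template and values:

    # Illustrative template and values; the real prompt is GRAPH_EXTRACTION_PROMPT.
    extraction_prompt = "Extract entities from: {input_text}\nUse delimiter: {record_delimiter}"
    prompt_variables = {"record_delimiter": "##"}
    prompt = extraction_prompt.format(**{
        **prompt_variables,
        "input_text": "Some document text.",
    })
    print(prompt)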
