Skip to content

Commit 22f9aeb

Browse files
authored
Add type checking of knowledge-graph (#583)
1 parent 4853ba3 commit 22f9aeb

17 files changed

+176
-84
lines changed

.github/workflows/ci-unit-tests.yml

Lines changed: 37 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,8 @@ jobs:
5252
cd docker/examples/local-llm
5353
sudo docker build -t local-llm .
5454
55-
unit-tests:
56-
name: Unit Tests (Python ${{ matrix.python-version }})
55+
lint:
56+
name: Lint (Python ${{ matrix.python-version }})
5757
needs: ["preconditions"]
5858
runs-on: ubuntu-latest
5959
strategy:
@@ -73,9 +73,43 @@ jobs:
7373
python-version: "${{ matrix.python-version }}"
7474

7575
- name: Run lint
76-
if: ${{ matrix.python-version != '3.12' }}
7776
uses: ./.github/actions/lint
7877

78+
type-check:
79+
name: Type Check
80+
needs: ["preconditions"]
81+
runs-on: ubuntu-latest
82+
steps:
83+
- name: Check out the repo
84+
uses: actions/checkout@v4
85+
86+
- name: "Setup: Python 3.11"
87+
uses: ./.github/actions/setup-python
88+
89+
- name: "Type check (knowledge-graph)"
90+
run: tox -e type -c libs/knowledge-graph && rm -rf libs/knowledge-graph/.tox
91+
92+
93+
unit-tests:
94+
name: Unit Tests (Python ${{ matrix.python-version }})
95+
needs: ["preconditions"]
96+
runs-on: ubuntu-latest
97+
strategy:
98+
matrix:
99+
python-version:
100+
- "3.12"
101+
- "3.11"
102+
- "3.10"
103+
- "3.9"
104+
steps:
105+
- name: Check out the repo
106+
uses: actions/checkout@v4
107+
108+
- name: "Setup: Python ${{ matrix.python-version }}"
109+
uses: ./.github/actions/setup-python
110+
with:
111+
python-version: "${{ matrix.python-version }}"
112+
79113
- name: "Unit tests (root)"
80114
# yamllint disable-line rule:line-length
81115
if: ${{ needs.preconditions.outputs.libs_langchain == 'true' || needs.preconditions.outputs.libs_colbert == 'true' || needs.preconditions.outputs.libs_llamaindex == 'true' }}

libs/knowledge-graph/pyproject.toml

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,18 +34,30 @@ pytest = "^8.1.1"
3434
pytest-asyncio = "^0.23.6"
3535
pytest-dotenv = "^0.5.2"
3636
pytest-rerunfailures = "^14.0"
37-
setuptools = "^70.0.0"
37+
mypy = "^1.10.1"
38+
types-pyyaml = "^6.0.1"
39+
pydantic = "<2" # for compatibility between LangChain and pydantic-yaml type checking
3840

3941
[build-system]
4042
requires = ["poetry-core"]
4143
build-backend = "poetry.core.masonry.api"
4244

4345
[tool.mypy]
44-
strict = true
45-
warn_unreachable = true
46-
pretty = true
47-
show_column_numbers = true
46+
disallow_any_generics = true
47+
disallow_incomplete_defs = true
48+
disallow_untyped_calls = true
49+
disallow_untyped_decorators = true
50+
disallow_untyped_defs = true
51+
follow_imports = "normal"
52+
ignore_missing_imports = true
53+
no_implicit_reexport = true
54+
show_error_codes = true
4855
show_error_context = true
56+
strict_equality = true
57+
strict_optional = true
58+
warn_redundant_casts = true
59+
warn_return_any = true
60+
warn_unused_ignores = true
4961

5062
[tool.pytest.ini_options]
5163
testpaths = ["tests"]

libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,9 @@ def add_graph_documents(
5858

5959
# TODO: should this include the types of each node?
6060
@override
61-
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: # noqa: B006
61+
def query(
62+
self, query: str, params: Optional[Dict[str, Any]] = None
63+
) -> List[Dict[str, Any]]:
6264
raise ValueError("Querying Cassandra should use `as_runnable`.")
6365

6466
@override
@@ -76,7 +78,9 @@ def get_structured_schema(self) -> Dict[str, Any]:
7678
def refresh_schema(self) -> None:
7779
raise NotImplementedError
7880

79-
def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = ()) -> Runnable:
81+
def as_runnable(
82+
self, steps: int = 3, edge_filters: Sequence[str] = ()
83+
) -> Runnable[Union[Node, Sequence[Node]], Iterable[Relation]]:
8084
"""Convert to a runnable.
8185
8286
Returns a runnable that retrieves the sub-graph near the input entity or

libs/knowledge-graph/ragstack_knowledge_graph/extraction.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
from typing import Dict, List, Sequence, Union, cast
1+
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union, cast
22

33
from langchain_community.graphs.graph_document import GraphDocument
44
from langchain_core.documents import Document
55
from langchain_core.language_models.chat_models import BaseChatModel
6+
from langchain_core.messages import SystemMessage
67
from langchain_core.prompts import (
78
ChatPromptTemplate,
89
HumanMessagePromptTemplate,
@@ -23,6 +24,9 @@
2324
)
2425
from .templates import load_template
2526

27+
if TYPE_CHECKING:
28+
from langchain_core.prompts.chat import MessageLikeRepresentation
29+
2630

2731
def _format_example(idx: int, example: Example) -> str:
2832
from pydantic_yaml import to_yaml_str
@@ -43,7 +47,7 @@ def __init__(
4347
self._validator = KnowledgeSchemaValidator(schema)
4448
self.strict = strict
4549

46-
messages = [
50+
messages: List[MessageLikeRepresentation] = [
4751
SystemMessagePromptTemplate(
4852
prompt=load_template(
4953
"extraction.md", knowledge_schema_yaml=schema.to_yaml_str()
@@ -52,22 +56,24 @@ def __init__(
5256
]
5357

5458
if examples:
55-
formatted = "\n\n".join(map(_format_example, examples))
56-
messages.append(SystemMessagePromptTemplate(prompt=formatted))
59+
formatted = "\n\n".join(
60+
[_format_example(i, example) for i, example in enumerate(examples)]
61+
)
62+
messages.append(SystemMessage(content=formatted))
5763

5864
messages.append(HumanMessagePromptTemplate.from_template("Input: {input}"))
5965

6066
prompt = ChatPromptTemplate.from_messages(messages)
61-
schema = create_simple_model(
67+
model_schema = create_simple_model(
6268
node_labels=[node.type for node in schema.nodes],
6369
rel_types=list({r.edge_type for r in schema.relationships}),
6470
)
6571
# TODO: Use "full" output so we can detect parsing errors?
66-
structured_llm = llm.with_structured_output(schema)
72+
structured_llm = llm.with_structured_output(model_schema)
6773
self._chain = prompt | structured_llm
6874

6975
def _process_response(
70-
self, document: Document, response: Union[Dict, BaseModel]
76+
self, document: Document, response: Union[Dict[str, Any], BaseModel]
7177
) -> GraphDocument:
7278
raw_graph = cast(_Graph, response)
7379
nodes = (
@@ -81,14 +87,14 @@ def _process_response(
8187
else []
8288
)
8389

84-
document = GraphDocument(
90+
graph_document = GraphDocument(
8591
nodes=nodes, relationships=relationships, source=document
8692
)
8793

8894
if self.strict:
89-
self._validator.validate_graph_document(document)
95+
self._validator.validate_graph_document(graph_document)
9096

91-
return document
97+
return graph_document
9298

9399
def extract(self, documents: List[Document]) -> List[GraphDocument]:
94100
"""Extract knowledge graphs from a list of documents."""

libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ def _deserialize_md_dict(md_string: str) -> Dict[str, Any]:
1919
return cast(Dict[str, Any], json.loads(md_string))
2020

2121

22-
def _parse_node(row) -> Node:
22+
def _parse_node(row: Any) -> Node:
2323
return Node(
2424
name=row.name,
2525
type=row.type,
@@ -106,7 +106,7 @@ def __init__(
106106
"""
107107
)
108108

109-
def _apply_schema(self):
109+
def _apply_schema(self) -> None:
110110
# Partition by `name` and cluster by `type`.
111111
# Each `(name, type)` pair is a unique node.
112112
# We can enumerate all `type` values for a given `name` to identify ambiguous
@@ -152,11 +152,13 @@ def _apply_schema(self):
152152
"""
153153
)
154154

155-
def _send_query_nearest_node(self, node: str, k: int = 1) -> ResponseFuture:
155+
def _send_query_nearest_node(
156+
self, embeddings: Embeddings, node: str, k: int = 1
157+
) -> ResponseFuture:
156158
return self._session.execute_async(
157159
self._query_nodes_by_embedding,
158160
(
159-
self._text_embeddings.embed_query(node),
161+
embeddings.embed_query(node),
160162
k,
161163
),
162164
)
@@ -173,13 +175,11 @@ def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node
173175
raise ValueError("Unable to query for nearest nodes without embeddings")
174176

175177
node_futures: Iterable[ResponseFuture] = [
176-
self._send_query_nearest_node(n, k) for n in nodes
178+
self._send_query_nearest_node(self._text_embeddings, n, k) for n in nodes
177179
]
178-
179-
nodes = {
180+
return {
180181
_parse_node(n) for node_future in node_futures for n in node_future.result()
181182
}
182-
return list(nodes)
183183

184184
# TODO: Introduce `ainsert` for async insertions.
185185
def insert(
@@ -253,9 +253,11 @@ def subgraph(
253253
for n in nodes
254254
]
255255

256-
nodes = [_parse_node(n) for future in node_futures for n in future.result()]
256+
graph_nodes = [
257+
_parse_node(n) for future in node_futures for n in future.result()
258+
]
257259

258-
return nodes, edges
260+
return graph_nodes, edges
259261

260262
def traverse(
261263
self,

libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Dict, List, Sequence, Union
2+
from typing import Dict, List, Self, Sequence, Union
33

44
from langchain_community.graphs.graph_document import GraphDocument
55
from langchain_core.pydantic_v1 import BaseModel
@@ -66,7 +66,7 @@ class KnowledgeSchema(BaseModel):
6666
"""Allowed relationships for the knowledge schema."""
6767

6868
@classmethod
69-
def from_file(cls, path: Union[str, Path]) -> "KnowledgeSchema":
69+
def from_file(cls, path: Union[str, Path]) -> Self:
7070
"""Load a KnowledgeSchema from a JSON or YAML file.
7171
7272
Args:
@@ -98,27 +98,27 @@ def __init__(self, schema: KnowledgeSchema) -> None:
9898
# TODO: Validate the relationship.
9999
# source/target type should exist in nodes, edge_type should exist in edges
100100

101-
def validate_graph_document(self, document: GraphDocument):
101+
def validate_graph_document(self, document: GraphDocument) -> None:
102102
"""Validate a graph document against the schema."""
103103
e = ValueError("Invalid graph document for schema")
104104
for node_type in {node.type for node in document.nodes}:
105105
if node_type not in self._nodes:
106106
e.add_note(f"No node type '{node_type}")
107107
for r in document.relationships:
108-
relationships = self._relationships.get(r.edge_type, None)
108+
relationships = self._relationships.get(r.type, None)
109109
if relationships is None:
110-
e.add_note(f"No edge type '{r.edge_type}")
110+
e.add_note(f"No edge type '{r.type}")
111111
else:
112112
relationship = next(
113113
candidate
114114
for candidate in relationships
115-
if r.source_type in candidate.source_types
116-
if r.target_type in candidate.target_types
115+
if r.source.type in candidate.source_types
116+
if r.target.type in candidate.target_types
117117
)
118118
if relationship is None:
119119
e.add_note(
120120
"No relationship allows "
121-
f"({r.source_id} -> {r.type} -> {r.target.type})"
121+
f"({r.source.id} -> {r.type} -> {r.target.type})"
122122
)
123123

124124
if e.__notes__:

libs/knowledge-graph/ragstack_knowledge_graph/render.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Iterable, Union
1+
from typing import Dict, Iterable, Tuple, Union
22

33
import graphviz
44
from langchain_community.graphs.graph_document import GraphDocument, Node
@@ -12,7 +12,7 @@ def _node_label(node: Node) -> str:
1212

1313
def print_graph_documents(
1414
graph_documents: Union[GraphDocument, Iterable[GraphDocument]],
15-
):
15+
) -> None:
1616
"""Prints the relationships in the graph documents."""
1717
if isinstance(graph_documents, GraphDocument):
1818
graph_documents = [graph_documents]
@@ -29,13 +29,13 @@ def render_graph_documents(
2929
) -> graphviz.Digraph:
3030
"""Renders the relationships in the graph documents."""
3131
if isinstance(graph_documents, GraphDocument):
32-
graph_documents = [GraphDocument]
32+
graph_documents = [graph_documents]
3333

3434
dot = graphviz.Digraph()
3535

36-
nodes = {}
36+
nodes: Dict[Tuple[Union[str, int], str], str] = {}
3737

38-
def _node_id(node: Node) -> int:
38+
def _node_id(node: Node) -> str:
3939
node_key = (node.id, node.type)
4040
if node_id := nodes.get(node_key):
4141
return node_id

libs/knowledge-graph/ragstack_knowledge_graph/runnables.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import Any, Dict, List, Optional
22

33
from langchain_core.language_models import BaseChatModel
44
from langchain_core.output_parsers import JsonOutputParser
@@ -26,7 +26,7 @@ def extract_entities(
2626
llm: BaseChatModel,
2727
keyword_extraction_prompt: str = QUERY_ENTITY_EXTRACT_PROMPT,
2828
node_types: Optional[List[str]] = None,
29-
) -> Runnable:
29+
) -> Runnable[Dict[str, Any], List[Node]]:
3030
"""Return a keyword-extraction runnable.
3131
3232
This will expect a dictionary containing the `"question"` to extract keywords from.

libs/knowledge-graph/ragstack_knowledge_graph/templates.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from os import path
2-
from typing import Callable, Union
2+
from typing import Callable, Union, cast
33

44
from langchain_core.prompts import PromptTemplate
55

@@ -12,5 +12,5 @@ def load_template(
1212
"""Load a template from a file."""
1313
template = PromptTemplate.from_file(path.join(TEMPLATE_PATH, filename))
1414
if kwargs:
15-
template = template.partial(**kwargs)
15+
template = cast(PromptTemplate, template.partial(**kwargs))
1616
return template

0 commit comments

Comments
 (0)