Skip to content

Commit a76a72c

Browse files
authored
Remove dependency on precisely (#449)
1 parent 221c4a8 commit a76a72c

File tree

9 files changed

+160
-70
lines changed

9 files changed

+160
-70
lines changed

libs/knowledge-graph/pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ testcontainers = "~3.7.1"
3333
# https://github.com/psf/requests/issues/6707
3434
requests = "<=2.31.0"
3535
pytest = "^8.1.1"
36-
precisely = "^0.1.9"
3736
pytest-asyncio = "^0.23.6"
3837
pytest-dotenv = "^0.5.2"
3938
setuptools = "^70.0.0"

libs/knowledge-graph/tests/conftest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55
from cassandra.cluster import Cluster, Session
6+
from dotenv import load_dotenv
67
from langchain.graphs.graph_document import GraphDocument, Node, Relationship
78
from langchain_core.documents import Document
89
from langchain_core.language_models import BaseChatModel
@@ -11,6 +12,8 @@
1112

1213
from ragstack_knowledge_graph.cassandra_graph_store import CassandraGraphStore
1314

15+
load_dotenv()
16+
1417

1518
@pytest.fixture(scope="session")
1619
def db_keyspace() -> str:
@@ -65,7 +68,9 @@ def llm() -> BaseChatModel:
6568

6669

6770
class DataFixture:
68-
def __init__(self, session: Session, keyspace: str, documents: List[GraphDocument]) -> None:
71+
def __init__(
72+
self, session: Session, keyspace: str, documents: List[GraphDocument]
73+
) -> None:
6974
self.session = session
7075
self.keyspace = "default_keyspace"
7176
self.uid = secrets.token_hex(8)
@@ -130,7 +135,9 @@ def marie_curie(db_session: Session, db_keyspace: str) -> Iterator[DataFixture]:
130135
Relationship(source=marie_curie, target=nobel_prize, type="WON"),
131136
Relationship(source=pierre_curie, target=nobel_prize, type="WON"),
132137
Relationship(source=marie_curie, target=pierre_curie, type="MARRIED_TO"),
133-
Relationship(source=marie_curie, target=university_of_paris, type="WORKED_AT"),
138+
Relationship(
139+
source=marie_curie, target=university_of_paris, type="WORKED_AT"
140+
),
134141
Relationship(source=marie_curie, target=professor, type="HAS_PROFESSION"),
135142
],
136143
source=Document(page_content="test_content"),

libs/knowledge-graph/tests/test_extraction.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from langchain_community.graphs.graph_document import Node, Relationship
55
from langchain_core.documents import Document
66
from langchain_core.language_models import BaseChatModel
7-
from precisely import assert_that, contains_exactly
87

98
from ragstack_knowledge_graph.extraction import (
109
KnowledgeSchema,
@@ -34,6 +33,7 @@ def extractor(llm: BaseChatModel) -> KnowledgeSchemaExtractor:
3433
Paris.
3534
"""
3635

36+
3737
@pytest.mark.flaky(retries=10, delay=0)
3838
def test_extraction(extractor: KnowledgeSchemaExtractor):
3939
results = extractor.extract([Document(page_content=MARIE_CURIE_SOURCE)])
@@ -50,9 +50,8 @@ def test_extraction(extractor: KnowledgeSchemaExtractor):
5050
# putting things into standard title case, etc.
5151
university_of_paris = Node(id="University Of Paris", type="Institution")
5252

53-
assert_that(
54-
results[0].nodes,
55-
contains_exactly(
53+
assert sorted(results[0].nodes, key=lambda x: x.id) == sorted(
54+
[
5655
marie_curie,
5756
polish,
5857
french,
@@ -61,18 +60,24 @@ def test_extraction(extractor: KnowledgeSchemaExtractor):
6160
nobel_prize,
6261
pierre_curie,
6362
university_of_paris,
64-
),
63+
],
64+
key=lambda x: x.id,
6565
)
66-
assert_that(
67-
results[0].relationships,
68-
contains_exactly(
66+
67+
assert sorted(
68+
results[0].relationships, key=lambda x: (x.source.id, x.target.id, x.type)
69+
) == sorted(
70+
[
6971
Relationship(source=marie_curie, target=polish, type="HAS_NATIONALITY"),
7072
Relationship(source=marie_curie, target=french, type="HAS_NATIONALITY"),
7173
Relationship(source=marie_curie, target=physicist, type="HAS_OCCUPATION"),
7274
Relationship(source=marie_curie, target=chemist, type="HAS_OCCUPATION"),
7375
Relationship(source=marie_curie, target=nobel_prize, type="RECEIVED"),
7476
Relationship(source=pierre_curie, target=nobel_prize, type="RECEIVED"),
75-
Relationship(source=marie_curie, target=university_of_paris, type="WORKED_AT"),
77+
Relationship(
78+
source=marie_curie, target=university_of_paris, type="WORKED_AT"
79+
),
7680
Relationship(source=marie_curie, target=pierre_curie, type="MARRIED_TO"),
77-
),
81+
],
82+
key=lambda x: (x.source.id, x.target.id, x.type),
7883
)

libs/knowledge-graph/tests/test_knowledge_graph.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import secrets
22
import pytest
3-
from precisely import assert_that, contains_exactly
43

54
from cassandra.cluster import Session
65
from ragstack_knowledge_graph.knowledge_graph import CassandraKnowledgeGraph
76
from ragstack_knowledge_graph.traverse import Node, Relation
87

98
from .conftest import DataFixture
109

10+
1111
def test_no_embeddings(db_session: Session, db_keyspace: str) -> None:
1212
uid = secrets.token_hex(8)
1313
node_table = f"entities_{uid}"
@@ -22,6 +22,7 @@ def test_no_embeddings(db_session: Session, db_keyspace: str) -> None:
2222
)
2323
graph.insert([Node(name="a", type="b")])
2424

25+
2526
def test_traverse_marie_curie(marie_curie: DataFixture) -> None:
2627
(result_nodes, result_edges) = marie_curie.graph_store.graph.subgraph(
2728
start=Node("Marie Curie", "Person"),
@@ -40,45 +41,67 @@ def test_traverse_marie_curie(marie_curie: DataFixture) -> None:
4041
Node(name="Professor", type="Profession"),
4142
]
4243
expected_edges = {
43-
Relation(Node("Marie Curie", "Person"), Node("Polish", "Nationality"), "HAS_NATIONALITY"),
44-
Relation(Node("Marie Curie", "Person"), Node("French", "Nationality"), "HAS_NATIONALITY"),
4544
Relation(
46-
Node("Marie Curie", "Person"), Node("Physicist", "Profession"), "HAS_PROFESSION"
45+
Node("Marie Curie", "Person"),
46+
Node("Polish", "Nationality"),
47+
"HAS_NATIONALITY",
48+
),
49+
Relation(
50+
Node("Marie Curie", "Person"),
51+
Node("French", "Nationality"),
52+
"HAS_NATIONALITY",
53+
),
54+
Relation(
55+
Node("Marie Curie", "Person"),
56+
Node("Physicist", "Profession"),
57+
"HAS_PROFESSION",
58+
),
59+
Relation(
60+
Node("Marie Curie", "Person"),
61+
Node("Chemist", "Profession"),
62+
"HAS_PROFESSION",
4763
),
48-
Relation(Node("Marie Curie", "Person"), Node("Chemist", "Profession"), "HAS_PROFESSION"),
4964
Relation(
50-
Node("Marie Curie", "Person"), Node("Professor", "Profession"), "HAS_PROFESSION"
65+
Node("Marie Curie", "Person"),
66+
Node("Professor", "Profession"),
67+
"HAS_PROFESSION",
5168
),
5269
Relation(
5370
Node("Marie Curie", "Person"),
5471
Node("Radioactivity", "Scientific concept"),
5572
"RESEARCHED",
5673
),
5774
Relation(Node("Marie Curie", "Person"), Node("Nobel Prize", "Award"), "WON"),
58-
Relation(Node("Marie Curie", "Person"), Node("Pierre Curie", "Person"), "MARRIED_TO"),
75+
Relation(
76+
Node("Marie Curie", "Person"), Node("Pierre Curie", "Person"), "MARRIED_TO"
77+
),
5978
Relation(
6079
Node("Marie Curie", "Person"),
6180
Node("University of Paris", "Organization"),
6281
"WORKED_AT",
6382
),
6483
}
65-
assert_that(result_edges, contains_exactly(*expected_edges))
66-
assert_that(result_nodes, contains_exactly(*expected_nodes))
84+
assert sorted(result_edges) == sorted(expected_edges)
85+
assert sorted(result_nodes) == sorted(expected_nodes)
6786

6887

6988
def test_fuzzy_search(marie_curie: DataFixture) -> None:
7089
if not marie_curie.has_embeddings:
71-
pytest.skip("Fuzzy search requires embeddings. Run with openai environment variables")
72-
result_nodes = marie_curie.graph_store.graph.query_nearest_nodes(["Marie", "Poland"])
90+
pytest.skip(
91+
"Fuzzy search requires embeddings. Run with openai environment variables"
92+
)
93+
result_nodes = marie_curie.graph_store.graph.query_nearest_nodes(
94+
["Marie", "Poland"]
95+
)
7396
expected_nodes = [
7497
Node(name="Marie Curie", type="Person"),
7598
Node(name="Polish", type="Nationality", properties={"European": True}),
7699
]
77-
assert_that(result_nodes, contains_exactly(*expected_nodes))
100+
assert sorted(result_nodes) == sorted(expected_nodes)
78101

79102
result_nodes = marie_curie.graph_store.graph.query_nearest_nodes(["European"], k=2)
80103
expected_nodes = [
81104
Node(name="Polish", type="Nationality", properties={"European": True}),
82105
Node(name="French", type="Nationality", properties={"European": True}),
83106
]
84-
assert_that(result_nodes, contains_exactly(*expected_nodes))
107+
assert sorted(result_nodes) == sorted(expected_nodes)
Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
from precisely import assert_that, contains_exactly
2-
31
from ragstack_knowledge_graph.runnables import extract_entities
42
from ragstack_knowledge_graph.traverse import Node
53

64

75
def test_extract_entities(llm):
86
extractor = extract_entities(llm)
9-
assert_that(
10-
extractor.invoke({"question": "Who is Marie Curie?"}),
11-
contains_exactly(Node("Marie Curie", "Person")),
12-
)
7+
assert extractor.invoke({"question": "Who is Marie Curie?"}) == [
8+
Node("Marie Curie", "Person")
9+
]

0 commit comments

Comments
 (0)