Skip to content

Commit 9b0f20d

Browse files
ds-filipknefel (Filip Knefel), mpolomdeepsense
authored
🔀 fix: change enums usage to work on all supported python versions (#329)
* Change enum usage to work on all supported python versions * Fix flake8 error * Version and changelog update; neo4j string enum fix * Enable python 3.11 and 3.12 for unit_tests workflow * Fix python 3.12 flake8 check * Fix nltk download * Enable python 3.11 and 3.12 for src and dst integration tests * Revert "Enable python 3.11 and 3.12 for src and dst integration tests" This reverts commit fc40dbc. * Revert "Fix python 3.12 flake8 check" This reverts commit 709647e. * Revert "Enable python 3.11 and 3.12 for unit_tests workflow" This reverts commit bbad641. * Fix neo4j tests * Revert "Fix neo4j tests" This reverts commit 7de281c. --------- Co-authored-by: Filip Knefel <[email protected]> Co-authored-by: Marek Połom <[email protected]>
1 parent f27c71c commit 9b0f20d

File tree

4 files changed

+32
-18
lines changed

4 files changed

+32
-18
lines changed

‎CHANGELOG.md‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## 0.3.14-dev0
2+
3+
### Fixes
4+
5+
* **Fix Neo4j Uploader string enum error**
6+
17
## 0.3.13
28

39
### Fixes

‎test/integration/connectors/test_neo4j.py‎

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,15 @@ async def validate_uploaded_graph(upload_file: Path):
199199
try:
200200
nodes_count = len((await driver.execute_query("MATCH (n) RETURN n"))[0])
201201
chunk_nodes_count = len(
202-
(await driver.execute_query(f"MATCH (n: {Label.CHUNK}) RETURN n"))[0]
202+
(await driver.execute_query(f"MATCH (n: {Label.CHUNK.value}) RETURN n"))[0]
203203
)
204204
document_nodes_count = len(
205-
(await driver.execute_query(f"MATCH (n: {Label.DOCUMENT}) RETURN n"))[0]
205+
(await driver.execute_query(f"MATCH (n: {Label.DOCUMENT.value}) RETURN n"))[0]
206206
)
207207
element_nodes_count = len(
208-
(await driver.execute_query(f"MATCH (n: {Label.UNSTRUCTURED_ELEMENT}) RETURN n"))[0]
208+
(await driver.execute_query(f"MATCH (n: {Label.UNSTRUCTURED_ELEMENT.value}) RETURN n"))[
209+
0
210+
]
209211
)
210212
with check:
211213
assert nodes_count == expected_nodes_count
@@ -217,12 +219,18 @@ async def validate_uploaded_graph(upload_file: Path):
217219
assert element_nodes_count == expected_element_count
218220

219221
records, _, _ = await driver.execute_query(
220-
f"MATCH ()-[r:{Relationship.PART_OF_DOCUMENT}]->(:{Label.DOCUMENT}) RETURN r"
222+
f"""
223+
MATCH ()-[r:{Relationship.PART_OF_DOCUMENT.value}]->(:{Label.DOCUMENT.value})
224+
RETURN r
225+
"""
221226
)
222227
part_of_document_count = len(records)
223228

224229
records, _, _ = await driver.execute_query(
225-
f"MATCH (:{Label.CHUNK})-[r:{Relationship.NEXT_CHUNK}]->(:{Label.CHUNK}) RETURN r"
230+
f"""
231+
MATCH (:{Label.CHUNK.value})-[r:{Relationship.NEXT_CHUNK.value}]->(:{Label.CHUNK.value})
232+
RETURN r
233+
"""
226234
)
227235
next_chunk_count = len(records)
228236

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.3.13" # pragma: no cover
1+
__version__ = "0.3.14-dev0" # pragma: no cover

‎unstructured_ingest/v2/processes/connectors/neo4j.py‎

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def run( # type: ignore
105105
output_filepath.parent.mkdir(parents=True, exist_ok=True)
106106

107107
with open(output_filepath, "w") as file:
108-
json.dump(_GraphData.from_nx(nx_graph).model_dump(), file, indent=4)
108+
file.write(_GraphData.from_nx(nx_graph).model_dump_json())
109109

110110
return output_filepath
111111

@@ -196,7 +196,7 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
196196

197197

198198
class _Node(BaseModel):
199-
model_config = ConfigDict(use_enum_values=True)
199+
model_config = ConfigDict()
200200

201201
id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
202202
labels: list[Label] = Field(default_factory=list)
@@ -207,20 +207,20 @@ def __hash__(self):
207207

208208

209209
class _Edge(BaseModel):
210-
model_config = ConfigDict(use_enum_values=True)
210+
model_config = ConfigDict()
211211

212212
source_id: str
213213
destination_id: str
214214
relationship: Relationship
215215

216216

217-
class Label(str, Enum):
217+
class Label(Enum):
218218
UNSTRUCTURED_ELEMENT = "UnstructuredElement"
219219
CHUNK = "Chunk"
220220
DOCUMENT = "Document"
221221

222222

223-
class Relationship(str, Enum):
223+
class Relationship(Enum):
224224
PART_OF_DOCUMENT = "PART_OF_DOCUMENT"
225225
PART_OF_CHUNK = "PART_OF_CHUNK"
226226
NEXT_CHUNK = "NEXT_CHUNK"
@@ -263,23 +263,23 @@ async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None: #
263263
async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
264264
for label in Label:
265265
logger.info(
266-
f"Adding id uniqueness constraint for nodes labeled '{label}'"
266+
f"Adding id uniqueness constraint for nodes labeled '{label.value}'"
267267
" if it does not already exist."
268268
)
269-
constraint_name = f"{label.lower()}_id"
269+
constraint_name = f"{label.value.lower()}_id"
270270
await client.execute_query(
271271
f"""
272272
CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
273-
FOR (n: {label}) REQUIRE n.id IS UNIQUE
273+
FOR (n: {label.value}) REQUIRE n.id IS UNIQUE
274274
"""
275275
)
276276

277277
async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
278278
logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
279279
_, summary, _ = await client.execute_query(
280280
f"""
281-
MATCH (n: {Label.DOCUMENT} {{id: $identifier}})
282-
MATCH (n)--(m: {Label.CHUNK}|{Label.UNSTRUCTURED_ELEMENT})
281+
MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
282+
MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
283283
DETACH DELETE m""",
284284
identifier=file_data.identifier,
285285
)
@@ -349,7 +349,7 @@ async def _execute_queries(
349349

350350
@staticmethod
351351
def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
352-
labels_string = ", ".join(labels)
352+
labels_string = ", ".join([label.value for label in labels])
353353
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
354354
query_string = f"""
355355
UNWIND $nodes AS node
@@ -366,7 +366,7 @@ def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple
366366
UNWIND $edges AS edge
367367
MATCH (u {{id: edge.source}})
368368
MATCH (v {{id: edge.destination}})
369-
MERGE (u)-[:{relationship}]->(v)
369+
MERGE (u)-[:{relationship.value}]->(v)
370370
"""
371371
parameters = {
372372
"edges": [

0 commit comments

Comments
 (0)