Skip to content

Commit a912368

Browse files
chore: mypy fixes (#6)
* chore: mypy initial fixes * chore: poetry config change + spanner import type ignore * chore: isort fix
1 parent ebacbc3 commit a912368

File tree

8 files changed

+44
-29
lines changed

8 files changed

+44
-29
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ pip-log.txt
3030
.cache
3131
.pytest_cache
3232

33+
# Mypy cache
34+
.mypy_cache
3335

3436
# Mac
3537
.DS_Store

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,9 @@ build-backend = "poetry.core.masonry.api"
3030
target-version = ['py39']
3131

3232
[tool.isort]
33-
profile = "black"
33+
profile = "black"
34+
35+
[tool.mypy]
36+
python_version = 3.9
37+
warn_unused_configs = true
38+
ignore_missing_imports = true

src/llama_index_spanner/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from .version import __version__
1818

1919
__all__ = [
20-
__version__,
21-
SpannerPropertyGraphStore,
22-
SpannerGraphTextToGQLRetriever,
23-
SpannerGraphCustomRetriever,
20+
"__version__",
21+
"SpannerPropertyGraphStore",
22+
"SpannerGraphTextToGQLRetriever",
23+
"SpannerGraphCustomRetriever",
2424
]

src/llama_index_spanner/property_graph_store.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414

1515
import itertools
1616
import json
17-
from typing import Any, Dict, List, Optional, Tuple
17+
from typing import Any, Dict, List, Optional, Sequence, Tuple
1818

19-
from google.cloud import spanner
19+
from google.cloud import spanner # type: ignore
2020
from llama_index.core.graph_stores.types import (
2121
ChunkNode,
2222
EntityNode,
@@ -26,7 +26,7 @@
2626
Triplet,
2727
)
2828
from llama_index.core.prompts import PromptTemplate, PromptType
29-
from llama_index.core.vector_stores.types import VectorStoreQuery
29+
from llama_index.core.vector_stores.types import MetadataFilter, VectorStoreQuery
3030

3131
from .prompts import DEFAULT_SPANNER_GQL_TEMPLATE
3232
from .schema import (
@@ -57,7 +57,7 @@ def node_from_json(label: str, json_node_properties: Dict[str, Any]) -> Labelled
5757
Returns:
5858
A LabelledNode.
5959
"""
60-
id_, name, text, properties, embedding = None, None, None, {}, None
60+
id_, name, text, properties, embedding = "", "", "", {}, None
6161
for k, v in json_node_properties.items():
6262
if k == ElementSchema.NODE_KEY_COLUMN_NAME:
6363
id_ = v
@@ -93,7 +93,7 @@ def edge_from_json(json_edge_properties: Dict[str, Any]) -> Relation:
9393
Returns:
9494
A Relation.
9595
"""
96-
source_id, target_id, properties, label = None, None, {}, None
96+
source_id, target_id, properties, label = "", "", {}, ""
9797
for k, v in json_edge_properties.items():
9898
if k == ElementSchema.NODE_KEY_COLUMN_NAME:
9999
source_id = v
@@ -116,7 +116,7 @@ def edge_from_json(json_edge_properties: Dict[str, Any]) -> Relation:
116116
def update_condition(
117117
cond: List[str],
118118
params: Dict[str, Any],
119-
schema: SpannerGraphSchema = None,
119+
schema: SpannerGraphSchema,
120120
ids: Optional[List[str]] = None,
121121
properties: Optional[Dict[str, Any]] = None,
122122
entity_names: Optional[List[str]] = None,
@@ -251,7 +251,7 @@ def __init__(
251251
def client(self):
252252
return self.impl
253253

254-
def upsert_nodes(self, nodes: List[LabelledNode]) -> None:
254+
def upsert_nodes(self, nodes: Sequence[LabelledNode]) -> None:
255255
"""Upserts nodes into the graph store.
256256
257257
This method takes a list of LabelledNodes and upserts them into the
@@ -372,7 +372,7 @@ def get(
372372
else "labels(n)[0]"
373373
)
374374
cond = ["1 = 1"]
375-
params = {}
375+
params: Dict[str, Any] = {}
376376

377377
if not update_condition(
378378
cond, params, self.schema, ids=ids, properties=properties
@@ -416,7 +416,7 @@ def get_triplets(
416416
label_field = ElementSchema.DYNAMIC_LABEL_COLUMN_NAME
417417

418418
cond = ["1 = 1"]
419-
params = {}
419+
params: Dict[str, Any] = {}
420420

421421
if not update_condition(
422422
cond,
@@ -550,8 +550,8 @@ def delete(
550550
node_key_field = ElementSchema.NODE_KEY_COLUMN_NAME
551551
target_node_key_field = ElementSchema.TARGET_NODE_KEY_COLUMN_NAME
552552

553-
cond = []
554-
params = {}
553+
cond: List[str] = []
554+
params: Dict[str, Any] = {}
555555

556556
if (
557557
update_condition(
@@ -580,7 +580,7 @@ def delete(
580580
):
581581
self.impl.delete(
582582
self.schema.labels[node_table_label].base_table_name,
583-
[node_data[1:] for node_data in nodes],
583+
[[node_id] for _, node_id in nodes],
584584
)
585585

586586
if relation_names and self.schema.edge_tables:
@@ -610,7 +610,10 @@ def delete(
610610
):
611611
self.impl.delete(
612612
self.schema.labels[edge_label].base_table_name,
613-
[edge_data[1:] for edge_data in edges],
613+
[
614+
[edge_id, edge_target_id, edge_label]
615+
for _, edge_id, edge_target_id, edge_label in edges
616+
],
614617
)
615618
else:
616619
data = self.structured_query(
@@ -669,7 +672,7 @@ def vector_query(
669672
- A list of LabelledNodes, representing the nodes that were found.
670673
- A list of floats, representing the similarity scores of the nodes.
671674
"""
672-
if not self.schema.graph_exists:
675+
if not self.schema.graph_exists or query.query_embedding is None:
673676
return ([], [])
674677

675678
query_condition = "1 = 1"
@@ -680,6 +683,11 @@ def vector_query(
680683
if query.filters:
681684
cond = []
682685
for i, query_filter in enumerate(query.filters.filters):
686+
if not isinstance(
687+
query_filter, MetadataFilter
688+
): # doesn't support nested MetadataFilters
689+
continue
690+
683691
if (
684692
query_filter.key not in self.schema.node_properties
685693
and (property_prefix + query_filter.key)

src/llama_index_spanner/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import json
1818
import re
19-
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
19+
from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple
2020

2121
from google.cloud.spanner_v1 import JsonObject, param_types
2222
from llama_index.core.graph_stores.types import (
@@ -38,7 +38,7 @@ def remove_empty_values(input_dict):
3838

3939

4040
def group_nodes(
41-
nodes: List[LabelledNode],
41+
nodes: Sequence[LabelledNode],
4242
) -> Dict[str, List[LabelledNode]]:
4343
"""Groups nodes by their respective types.
4444

src/llama_index_spanner/spanner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from abc import ABC, abstractmethod
1616
from typing import Any, Dict, List, Optional, Tuple
1717

18-
from google.cloud import spanner
18+
from google.cloud import spanner # type: ignore
1919

2020
from .type_utils import TypeUtility
2121

@@ -65,7 +65,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None:
6565

6666
@abstractmethod
6767
def insert_or_update(
68-
self, table: str, columns: Tuple[str], values: List[List[Any]]
68+
self, table: str, columns: Tuple[str, ...], values: List[List[Any]]
6969
) -> None:
7070
"""Insert or update the table.
7171
@@ -130,7 +130,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None:
130130
return op.result(options.get("timeout", DEFAULT_DDL_TIMEOUT))
131131

132132
def insert_or_update(
133-
self, table: str, columns: Tuple[str], values: List[List[Any]]
133+
self, table: str, columns: Tuple[str, ...], values: List[List[Any]]
134134
) -> None:
135135
for i in range(0, len(values), MUTATION_BATCH_SIZE):
136136
value_batch = values[i : i + MUTATION_BATCH_SIZE]

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def get_resources(
5555
graph_store = get_spanner_property_graph_store(
5656
graph_name_suffix, use_flexible_schema, clean_up
5757
)
58-
storage_context = StorageContext.from_defaults(graph_store=graph_store)
58+
storage_context = StorageContext.from_defaults(property_graph_store=graph_store)
5959
llm = GoogleGenAI(
6060
model="gemini-2.0-flash",
6161
)

0 commit comments

Comments
 (0)