Skip to content

Commit e269fa3

Browse files
committed
chore: update imports in tests + update types in schema
1 parent ac0648c commit e269fa3

File tree

6 files changed

+29
-46
lines changed

6 files changed

+29
-46
lines changed

src/llama_index_spanner/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from .version import __version__
16+
from .property_graph_store import SpannerPropertyGraphStore
1517
from .graph_retriever import (
1618
SpannerGraphCustomRetriever,
1719
SpannerGraphTextToGQLRetriever,
1820
)
19-
from .property_graph_store import SpannerPropertyGraphStore
20-
from .version import __version__
2121

2222
__all__ = [
2323
__version__,

src/llama_index_spanner/schema.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def remove_empty_values(input_dict):
3838

3939
def group_nodes(
4040
nodes: List[LabelledNode],
41-
) -> dict[str, List[LabelledNode]]:
41+
) -> Dict[str, List[LabelledNode]]:
4242
"""Groups nodes by their respective types.
4343
4444
Args:
@@ -47,7 +47,7 @@ def group_nodes(
4747
Returns:
4848
A dictionary mapping node types to lists of LabelledNodes.
4949
"""
50-
nodes_group: CaseInsensitiveDict[dict[str, LabelledNode]] = (
50+
nodes_group: CaseInsensitiveDict[Dict[str, LabelledNode]] = (
5151
CaseInsensitiveDict()
5252
)
5353
for node in nodes:
@@ -63,7 +63,7 @@ def group_nodes(
6363

6464
def group_edges(
6565
edges: List[Tuple[Relation, str, str, str, str]],
66-
) -> dict[str, List[Tuple[Relation, str, str]]]:
66+
) -> Dict[str, List[Tuple[Relation, str, str]]]:
6767
"""Groups edges by their respective types.
6868
6969
Args:
@@ -79,7 +79,7 @@ def group_edges(
7979
- The target node table name.
8080
"""
8181
edges_group: CaseInsensitiveDict[
82-
dict[Tuple[str, str, str], Tuple[Relation, str, str]]
82+
Dict[Tuple[str, str, str], Tuple[Relation, str, str]]
8383
] = CaseInsensitiveDict()
8484
for edge, source_label, source_table, target_label, target_table in edges:
8585
edge_name = (
@@ -159,9 +159,9 @@ class ElementSchema(object):
159159
key_columns: List[str]
160160
base_table_name: str
161161
labels: List[str]
162-
properties: Dict[str, str]
162+
properties: CaseInsensitiveDict[str]
163163
# types: A dictionary where keys are property names (strings) and values are Spanner type definitions
164-
types: Dict[str, param_types.Type]
164+
types: CaseInsensitiveDict[param_types.Type]
165165
source: NodeReference
166166
target: NodeReference
167167

@@ -181,7 +181,7 @@ def is_dynamic_schema(self) -> bool:
181181
def make_node_schema(
182182
node_label: str,
183183
graph_name: str,
184-
property_types: Dict[str, param_types.Type],
184+
property_types: CaseInsensitiveDict[param_types.Type],
185185
) -> ElementSchema:
186186
"""Creates a node schema for a given node type and label."""
187187
node = ElementSchema()
@@ -199,7 +199,7 @@ def make_edge_schema(
199199
edge_label: str,
200200
graph_schema: SpannerGraphSchema,
201201
key_columns: List[str],
202-
property_types: Dict[str, param_types.Type],
202+
property_types: CaseInsensitiveDict[param_types.Type],
203203
source_node_table: str,
204204
target_node_table: str,
205205
) -> ElementSchema:
@@ -741,20 +741,20 @@ def __init__(
741741
)
742742

743743
self.graph_name: str = graph_name
744-
self.node_tables: Dict[str, ElementSchema] = CaseInsensitiveDict({})
745-
self.edge_tables: Dict[str, ElementSchema] = CaseInsensitiveDict({})
746-
self.labels: Dict[str, Label] = CaseInsensitiveDict({})
747-
self.properties: Dict[str, param_types.Type] = CaseInsensitiveDict({})
748-
self.node_properties: Dict[str, param_types.Type] = CaseInsensitiveDict({})
744+
self.node_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({})
745+
self.edge_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({})
746+
self.labels: CaseInsensitiveDict[Label] = CaseInsensitiveDict({})
747+
self.properties: CaseInsensitiveDict[param_types.Type] = CaseInsensitiveDict({})
748+
self.node_properties: CaseInsensitiveDict[param_types.Type] = CaseInsensitiveDict({})
749749
self.use_flexible_schema = use_flexible_schema
750750
self.static_node_properties = set(static_node_properties or [])
751751
self.static_edge_properties = set(static_edge_properties or [])
752752
self.graph_exists = False
753753

754754
def evolve_from_nodes(
755755
self,
756-
nodes: dict[str, List[LabelledNode]],
757-
) -> Tuple[List[str], dict[str, ElementSchema]]:
756+
nodes: Dict[str, List[LabelledNode]],
757+
) -> Tuple[List[str], Dict[str, ElementSchema]]:
758758
"""Evolves the graph schema based on new nodes and edges.
759759
760760
This method updates the internal schema representation by adding new
@@ -789,8 +789,8 @@ def evolve_from_nodes(
789789

790790
def evolve_from_edges(
791791
self,
792-
edges: dict[str, List[Tuple[Relation, str, str]]],
793-
) -> Tuple[List[str], dict[str, ElementSchema], dict[str, ElementSchema]]:
792+
edges: Dict[str, List[Tuple[Relation, str, str]]],
793+
) -> Tuple[List[str], Dict[str, ElementSchema]]:
794794
"""Evolves the graph schema based on new edges.
795795
796796
This method updates the internal schema representation by adding new
@@ -902,7 +902,7 @@ def to_ddl(self) -> str:
902902

903903
def construct_label_and_properties(
904904
target_label: str,
905-
labels: Dict[str, Label],
905+
labels: CaseInsensitiveDict[Label],
906906
element: ElementSchema,
907907
) -> str:
908908
props = labels[target_label].prop_names
@@ -917,7 +917,7 @@ def construct_label_and_properties(
917917

918918
def construct_label_and_properties_list(
919919
target_labels: List[str],
920-
labels: Dict[str, Label],
920+
labels: CaseInsensitiveDict[Label],
921921
element: ElementSchema,
922922
) -> str:
923923
return "\n".join((
@@ -942,7 +942,7 @@ def construct_node_reference(
942942
)
943943

944944
def construct_element_table(
945-
element: ElementSchema, labels: Dict[str, Label]
945+
element: ElementSchema, labels: CaseInsensitiveDict[Label]
946946
) -> str:
947947
definition = [
948948
"{} AS {}".format(

tests/__init__.py

Whitespace-only changes.

tests/test_graph_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from llama_index.core.query_engine import RetrieverQueryEngine
2424
from llama_index.readers.wikipedia import WikipediaReader
2525
import pytest
26-
from src.llama_index_spanner.graph_retriever import (
26+
from llama_index_spanner.graph_retriever import (
2727
SpannerGraphCustomRetriever,
2828
SpannerGraphTextToGQLRetriever,
2929
)

tests/test_spanner_property_graph_store.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import datetime
1615
from typing import Generator
1716
from llama_index.core.graph_stores.types import (
1817
ChunkNode,
1918
EntityNode,
20-
LabelledNode,
21-
PropertyGraphStore,
2219
Relation,
2320
)
2421
from llama_index.core.vector_stores.types import (
@@ -30,8 +27,8 @@
3027
VectorStoreQuery,
3128
)
3229
import pytest
33-
from src.llama_index_spanner.property_graph_store import SpannerPropertyGraphStore
34-
from src.llama_index_spanner.schema import ElementSchema
30+
from llama_index_spanner.property_graph_store import SpannerPropertyGraphStore
31+
from llama_index_spanner.schema import ElementSchema
3532
from tests.utils import (
3633
get_random_suffix,
3734
get_spanner_property_graph_store,
@@ -55,7 +52,7 @@
5552

5653

5754
@pytest.fixture
58-
def property_graph_store_static() -> Generator[SpannerPropertyGraphStore]:
55+
def property_graph_store_static() -> Generator[SpannerPropertyGraphStore, None, None]:
5956
"""Provides a fresh SpannerPropertyGraphStore for each test."""
6057
graph_store = get_spanner_property_graph_store(
6158
graph_name_suffix=get_random_suffix()
@@ -65,7 +62,7 @@ def property_graph_store_static() -> Generator[SpannerPropertyGraphStore]:
6562

6663

6764
@pytest.fixture
68-
def property_graph_store_dynamic() -> Generator[SpannerPropertyGraphStore]:
65+
def property_graph_store_dynamic() -> Generator[SpannerPropertyGraphStore, None, None]:
6966
"""Provides a fresh SpannerPropertyGraphStore for each test."""
7067
graph_store_dynamic_schema = get_spanner_property_graph_store(
7168
use_flexible_schema=True,

tests/utils.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from llama_index.core.storage import StorageContext
1919
from llama_index.embeddings.google_genai import GoogleGenAIEmbedding
2020
from llama_index.llms.google_genai import GoogleGenAI
21-
from src.llama_index_spanner import SpannerGraphStore, SpannerPropertyGraphStore
21+
from llama_index_spanner import SpannerPropertyGraphStore
2222

2323
spanner_instance_id = (
2424
os.environ.get("SPANNER_INSTANCE_ID") or "graphdb-spanner-llama"
@@ -29,27 +29,13 @@
2929
spanner_graph_name = os.environ.get("SPANNER_GRAPH_NAME") or "llama_index_graph"
3030

3131

32-
def get_spanner_graph_store(
33-
graph_name_suffix: str = "", clean_up: bool = False
34-
) -> SpannerGraphStore:
35-
"""Get a SpannerGraphStore instance for testing."""
36-
graph_name = spanner_graph_name
37-
if graph_name_suffix:
38-
graph_name += "_" + graph_name_suffix
39-
return SpannerGraphStore(
40-
instance_id=spanner_instance_id,
41-
database_id=spanner_database_id,
42-
graph_name=graph_name,
43-
clean_up=clean_up,
44-
)
45-
4632

4733
def get_spanner_property_graph_store(
4834
graph_name_suffix: str = "",
4935
use_flexible_schema: bool = False,
5036
clean_up: bool = False,
5137
) -> SpannerPropertyGraphStore:
52-
"""Get a SpannerGraphStore instance for testing."""
38+
"""Get a SpannerPropertyGraphStore instance for testing."""
5339
graph_name = spanner_graph_name
5440
if graph_name_suffix:
5541
graph_name += "_" + graph_name_suffix

0 commit comments

Comments
 (0)