diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index 93b42ec0..34e40327 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -15,10 +15,22 @@ from __future__ import annotations import json +import logging import re import string from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from google.cloud import spanner from google.cloud.spanner_v1 import JsonObject, param_types @@ -35,6 +47,8 @@ EDGE_KIND = "EDGE" USER_AGENT_GRAPH_STORE = "langchain-google-spanner-python:graphstore/" + __version__ +logger = logging.getLogger(__name__) + class NodeWrapper(object): """Wrapper around Node to support set operations using node id""" @@ -205,18 +219,103 @@ class ElementSchema(object): source: NodeReference target: NodeReference + # DYNAMIC LABEL() + # DYNAMIC PROPERTIES() + dynamic_label_expr: Optional[str] = None + dynamic_property_expr: Optional[str] = None + + # Cache of dynamically fetched labels and properties. + dynamic_schema: Optional[CaseInsensitiveDict[DynamicLabel]] = None + # Cache of dynamically fetched edge patterns. + dynamic_edge_patterns: Optional[List[Tuple[str, str, str]]] = None + def is_dynamic_schema(self) -> bool: return ( - self.types.get(ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, None) - == param_types.JSON + self.dynamic_label_expr is not None + or self.dynamic_property_expr is not None ) + def refresh_dynamic_schema(self, dynamic_schema_util: DynamicSchemaUtility): + if self.kind == NODE_KIND: + self.dynamic_schema = dynamic_schema_util.get_dynamic_node_schema( + self.labels + ) + else: + self.dynamic_schema = dynamic_schema_util.get_dynamic_edge_schema( + self.labels + ) + self.dynamic_edge_patterns = dynamic_schema_util.get_dynamic_edge_patterns( + self.labels + ) + + def get_label_and_properties(self, graph: SpannerGraphSchema): + + def get_readable_property(pname, ptype, json_type=None): + prop = { + "name": pname, + "type": TypeUtility.spanner_type_to_schema_str(ptype), + } + # Dynamic properties will have json_types: this represents the + # underlying data type of the json value. + if json_type: + prop["json_type"] = json_type + return prop + + if self.dynamic_schema: + return { + lname: [ + get_readable_property( + pname, self.types.get(pname, param_types.JSON), ptype + ) + for pname, ptype in label.properties + ] + for lname, label in self.dynamic_schema.items() + # Ignore static labels. + if lname not in self.labels + } + return { + label: [ + get_readable_property(pname, self.types[pname]) + for pname in sorted(graph.labels[label].prop_names) + if pname in self.types + ] + for label in sorted(self.labels) + } + + def get_edge_patterns(self, graph: SpannerGraphSchema): + assert self.kind == EDGE_KIND + source = graph.get_node_schema(self.source.node_name) + assert source is not None + target = graph.get_node_schema(self.target.node_name) + assert target is not None + if self.dynamic_edge_patterns: + return [ + (source_node_label, label, target_node_label) + for ( + source_node_label, + label, + target_node_label, + ) in self.dynamic_edge_patterns + # Ignore static labels. + if label not in self.labels + and source_node_label not in source.labels + and target_node_label not in target.labels + ] + return [ + (source_node_label, label, target_node_label) + for label in sorted(self.labels) + for source_node_label in source.labels + for target_node_label in target.labels + ] + @staticmethod def make_node_schema( node_type: str, node_label: str, graph_name: str, property_types: CaseInsensitiveDict, + dynamic_label_expr: Optional[str] = None, + dynamic_property_expr: Optional[str] = None, ) -> ElementSchema: node = ElementSchema() node.types = property_types @@ -226,6 +325,8 @@ def make_node_schema( node.name = node_type node.kind = NODE_KIND node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME] + node.dynamic_label_expr = dynamic_label_expr + node.dynamic_property_expr = dynamic_property_expr return node @staticmethod @@ -237,6 +338,8 @@ def make_edge_schema( property_types: CaseInsensitiveDict, source_node_type: str, target_node_type: str, + dynamic_label_expr: Optional[str] = None, + dynamic_property_expr: Optional[str] = None, ) -> ElementSchema: edge = ElementSchema() edge.types = property_types @@ -270,6 +373,8 @@ def make_edge_schema( [ElementSchema.NODE_KEY_COLUMN_NAME], [ElementSchema.TARGET_NODE_KEY_COLUMN_NAME], ) + edge.dynamic_label_expr = dynamic_label_expr + edge.dynamic_property_expr = dynamic_property_expr return edge @staticmethod @@ -351,7 +456,12 @@ def from_dynamic_nodes( ) ) return ElementSchema.make_node_schema( - NODE_KIND, NODE_KIND, graph_schema.graph_name, types + NODE_KIND, + NODE_KIND, + graph_schema.graph_name, + types, + dynamic_label_expr=ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + dynamic_property_expr=ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, ) @staticmethod @@ -468,6 +578,8 @@ def from_dynamic_edges( types, edges[0].source.type, edges[0].target.type, + dynamic_label_expr=ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + dynamic_property_expr=ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, ) def add_nodes( @@ -499,6 +611,19 @@ def add_nodes( properties[ElementSchema.NODE_KEY_COLUMN_NAME] = node.id if self.is_dynamic_schema(): + assert ( + self.dynamic_label_expr == ElementSchema.DYNAMIC_LABEL_COLUMN_NAME + ), "Require dynamic label expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + self.dynamic_label_expr, + ) + assert ( + self.dynamic_property_expr + == ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME + ), "Require dynamic property expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, + self.dynamic_property_expr, + ) dynamic_properties = { k: TypeUtility.value_for_json(v) for k, v in node.properties.items() @@ -549,6 +674,19 @@ def add_edges( properties[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = edge.target.id if self.is_dynamic_schema(): + assert ( + self.dynamic_label_expr == ElementSchema.DYNAMIC_LABEL_COLUMN_NAME + ), "Require dynamic label expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + self.dynamic_label_expr, + ) + assert ( + self.dynamic_property_expr + == ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME + ), "Require dynamic property expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, + self.dynamic_property_expr, + ) dynamic_properties = { k: TypeUtility.value_for_json(v) for k, v in edge.properties.items() @@ -618,6 +756,9 @@ def from_info_schema( element_schema["destinationNodeTable"]["nodeTableColumns"], element_schema["destinationNodeTable"]["edgeTableColumns"], ) + + element.dynamic_label_expr = element_schema.get("dynamicLabelExpr") + element.dynamic_property_expr = element_schema.get("dynamicPropertyExpr") return element def to_ddl(self, graph_schema: SpannerGraphSchema) -> str: @@ -743,6 +884,11 @@ def evolve(self, new_schema: ElementSchema) -> List[str]: ] self.properties.update(new_schema.properties) self.types.update(new_schema.types) + + self.dynamic_label_expr = new_schema.dynamic_label_expr + self.dynamic_property_expr = new_schema.dynamic_property_expr + self.dynamic_schema = new_schema.dynamic_schema + self.dynamic_edge_patterns = new_schema.dynamic_edge_patterns return ddls @@ -753,6 +899,9 @@ def __init__(self, name: str, prop_names: set[str]): self.name = name self.prop_names = prop_names + def __repr__(self): + return f"Label({self.name}, {self.prop_names})" + class NodeReference(object): """Schema representation of a source or destination node reference.""" @@ -763,6 +912,112 @@ def __init__(self, node_name: str, node_keys: List[str], edge_keys: List[str]): self.edge_keys = edge_keys +class DynamicLabel(object): + """Representation of a dynamic label.""" + + def __init__(self, name: str, properties: List[Tuple[str, str]]): + self.name = name + self.properties = properties + + +class DynamicSchemaUtility(object): + """Utility class that dynamically fetches graph schema.""" + + # Sample a list of (label, properties) for nodes of static label_expr. + NODE_DYNAMIC_SCHEMA_QUERY_TEMPLATE = """ + GRAPH `{graph_id}` + MATCH (n:{label_expr}) + LET json = SAFE_TO_JSON(n).properties + FOR label IN LABELS(n) + RETURN label, ANY_VALUE(json) AS json + NEXT + LET json_fields = JSON_KEYS(json) + RETURN label, ARRAY {{ + GRAPH `{graph_id}` + FOR field IN json_fields + FILTER json[field] IS NOT NULL + LET type = JSON_TYPE(json[field]) + FILTER type != 'null' + RETURN STRUCT(field, type) AS field + }} AS properties + """ + + # Sample a list of (label, properties) for edges of static label_expr. + EDGE_DYNAMIC_SCHEMA_QUERY_TEMPLATE = """ + GRAPH `{graph_id}` + MATCH -[n:{label_expr}]-> + LET json = SAFE_TO_JSON(n).properties + FOR label IN LABELS(n) + RETURN label, ANY_VALUE(json) AS json + NEXT + LET json_fields = JSON_KEYS(json) + RETURN label, ARRAY {{ + GRAPH `{graph_id}` + FOR field IN json_fields + FILTER json[field] IS NOT NULL + LET type = JSON_TYPE(json[field]) + FILTER type != 'null' + RETURN STRUCT(field, type) AS property + ORDER BY field + }} AS properties + ORDER BY label + """ + + # Find all (source_node_label, edge_label, target_node_label) triplets. + EDGE_PATTERN_QUERY_TEMPLATE = """ + GRAPH `{graph_id}` + MATCH (src) -[n:{label_expr}]-> (dst) + FOR edge_label IN LABELS(n) + FOR src_label IN LABELS(src) + FOR dst_label IN LABELS(dst) + RETURN DISTINCT src_label, edge_label, dst_label + ORDER BY src_label, edge_label, dst_label + """ + + def __init__(self, graph_name: str, impl: SpannerInterface): + self._graph_name = graph_name + self._impl = impl + + @staticmethod + def make_label_expr(labels: List[str]) -> str: + return " & ".join([f"`{label}`" for label in labels]) + + def get_dynamic_node_schema(self, labels: List[str]): + return self._get_dynamic_schema(labels, self.NODE_DYNAMIC_SCHEMA_QUERY_TEMPLATE) + + def get_dynamic_edge_schema(self, labels: List[str]): + return self._get_dynamic_schema(labels, self.EDGE_DYNAMIC_SCHEMA_QUERY_TEMPLATE) + + def _get_dynamic_schema(self, labels: List[str], query_template: str): + label_expr = self.make_label_expr(labels) + return CaseInsensitiveDict( + { + row["label"]: DynamicLabel( + name=row["label"], + properties=row["properties"], + ) + for row in self._impl.query( + query_template.format( + graph_id=self._graph_name, label_expr=label_expr + ) + ) + } + ) + + def get_dynamic_edge_patterns(self, labels: List[str]): + return set( + { + (row["src_label"], row["edge_label"], row["dst_label"]) + for row in self._impl.query( + self.EDGE_PATTERN_QUERY_TEMPLATE.format( + graph_id=self._graph_name, + label_expr=self.make_label_expr(labels), + ) + ) + } + ) + + class SpannerGraphSchema(object): """Schema representation of a property graph.""" @@ -778,6 +1033,7 @@ def __init__( use_flexible_schema: bool, static_node_properties: List[str] = [], static_edge_properties: List[str] = [], + dynamic_schema_util: Optional[DynamicSchemaUtility] = None, ): """Initializes the graph schema. @@ -808,6 +1064,7 @@ def __init__( self.use_flexible_schema = use_flexible_schema self.static_node_properties = set(static_node_properties) self.static_edge_properties = set(static_edge_properties) + self.dynamic_schema_util = dynamic_schema_util def evolve(self, graph_documents: List[GraphDocument]) -> List[str]: """Evolves current schema into a schema representing the input documents. @@ -859,11 +1116,15 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None: ) for node in info_schema["nodeTables"]: node_schema = ElementSchema.from_info_schema(node, decl_by_types) + if node_schema.is_dynamic_schema() and self.dynamic_schema_util: + node_schema.refresh_dynamic_schema(self.dynamic_schema_util) self._update_node_schema(node_schema) self._update_labels_and_properties(node_schema) for edge in info_schema.get("edgeTables", []): edge_schema = ElementSchema.from_info_schema(edge, decl_by_types) + if edge_schema.is_dynamic_schema() and self.dynamic_schema_util: + edge_schema.refresh_dynamic_schema(self.dynamic_schema_util) self._update_edge_schema(edge_schema) self._update_labels_and_properties(edge_schema) @@ -931,63 +1192,33 @@ def __repr__(self) -> str: Returns: str: a string representation of the graph schema. """ - properties = CaseInsensitiveDict( - { - k: TypeUtility.spanner_type_to_schema_str(v) - for k, v in self.properties.items() - } - ) - node_labels = {label for node in self.nodes.values() for label in node.labels} - edge_labels = {label for edge in self.edges.values() for label in edge.labels} - Triplet = Tuple[ElementSchema, ElementSchema, ElementSchema] - triplets_per_label: CaseInsensitiveDict[List[Triplet]] = CaseInsensitiveDict({}) + node_properties_per_label: Dict[str, Dict] = {} + edge_properties_per_label: Dict[str, Dict] = {} + edge_patterns_per_label: Dict[str, Set[str]] = {} + for node in self.nodes.values(): + node_properties_per_label.update(node.get_label_and_properties(self)) + for edge in self.edges.values(): - for label in edge.labels: - source_node = self.get_node_schema(edge.source.node_name) - target_node = self.get_node_schema(edge.target.node_name) - if source_node is None: - raise ValueError(f"Source node {edge.source.node_name} not found") - if target_node is None: - raise ValueError(f"Tource node {edge.target.node_name} not found") - triplets_per_label.setdefault(label, []).append( - (source_node, edge, target_node) + edge_properties_per_label.update(edge.get_label_and_properties(self)) + for src_node_label, label, tgt_node_label in edge.get_edge_patterns(self): + edge_patterns_per_label.setdefault(label, set()).add( + "(:{}) -[:{}]-> (:{})".format(src_node_label, label, tgt_node_label) ) return json.dumps( { "Name of graph": self.graph_name, - "Node properties per node label": { - label: [ - { - "name": name, - "type": properties[name], - } - for name in sorted(self.labels[label].prop_names) - ] - for label in sorted(node_labels) - }, - "Edge properties per edge label": { - label: [ - { - "name": name, - "type": properties[name], - } - for name in sorted(self.labels[label].prop_names) - ] - for label in sorted(edge_labels) - }, - "Possible edges per label": { - label: [ - "(:{}) -[:{}]-> (:{})".format( - source_node_label, label, target_node_label - ) - for (source, edge, target) in triplets - for source_node_label in source.labels - for target_node_label in target.labels - ] - for label, triplets in triplets_per_label.items() - }, + "Node properties per node label": dict( + sorted(node_properties_per_label.items()) + ), + "Edge properties per edge label": dict( + sorted(edge_properties_per_label.items()) + ), + "Possible edges per label": dict( + sorted(edge_patterns_per_label.items()) + ), }, indent=2, + default=lambda s: sorted(s), ) def to_ddl(self) -> str: @@ -1019,12 +1250,17 @@ def construct_label_and_properties_list( labels: CaseInsensitiveDict[Label], element: ElementSchema, ) -> str: - return "\n".join( - ( - construct_label_and_properties(target_label, labels, element) - for target_label in target_labels + clauses = [ + construct_label_and_properties(target_label, labels, element) + for target_label in target_labels + ] + if element.dynamic_label_expr: + clauses.append("DYNAMIC LABEL ({})".format(element.dynamic_label_expr)) + if element.dynamic_property_expr: + clauses.append( + "DYNAMIC PROPERTIES ({})".format(element.dynamic_property_expr) ) - ) + return "\n".join(clauses) def construct_columns(cols: List[str]) -> str: return ", ".join(to_identifiers(cols)) @@ -1176,7 +1412,6 @@ def add_edges( """ edge_schema = self.get_edge_schema(self.edge_type_name(name)) if edge_schema is None: - print(list(self.edges.keys())) raise ValueError("Unknown edge schema `%s`" % name) for v in edge_schema.add_edges(name, edges): yield v @@ -1265,7 +1500,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None: return op = self.database.update_ddl(ddl_statements=ddls) - print("Waiting for DDL operations to complete...") + logger.info("Waiting for DDL operations to complete...") return op.result(options.get("timeout", DEFAULT_DDL_TIMEOUT)) def insert_or_update( @@ -1307,6 +1542,7 @@ def __init__( properties as static. timeout (Optional[float]): The timeout for queries in seconds. """ + self.graph_name = graph_name self.impl = impl or SpannerImpl( instance_id, database_id, @@ -1316,8 +1552,9 @@ def __init__( self.schema = SpannerGraphSchema( graph_name, use_flexible_schema, - static_node_properties, - static_edge_properties, + static_node_properties=static_node_properties, + static_edge_properties=static_edge_properties, + dynamic_schema_util=DynamicSchemaUtility(graph_name, self.impl), ) self.refresh_schema() @@ -1345,25 +1582,28 @@ def add_graph_documents( ddls = self.schema.evolve(graph_documents) if ddls: self.impl.apply_ddls(ddls) - self.refresh_schema() else: - print("No schema change required...") + logger.info("No schema change required...") nodes, edges = partition_graph_docs(graph_documents) for name, elements in nodes.items(): if len(elements) == 0: continue for table, columns, rows in self.schema.add_nodes(name, elements): - print("Insert nodes of type `{}`...".format(name)) + logger.info("Insert nodes of type `{}`...".format(name)) self.impl.insert_or_update(table, columns, rows) for name, elements in edges.items(): if len(elements) == 0: continue for table, columns, rows in self.schema.add_edges(name, elements): - print("Insert edges of type `{}`...".format(name)) + logger.info("Insert edges of type `{}`...".format(name)) self.impl.insert_or_update(table, columns, rows) + # Refresh schema after data insertion because json property is sampled + # over the actual data. + self.refresh_schema() + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Spanner database. @@ -1435,5 +1675,9 @@ def cleanup(self): ] ) self.schema = SpannerGraphSchema( - self.schema.graph_name, self.schema.use_flexible_schema + self.schema.graph_name, + self.schema.use_flexible_schema, + self.schema.static_node_properties, + self.schema.static_edge_properties, + self.schema.dynamic_schema_util, ) diff --git a/tests/integration/test_spanner_graph_store.py b/tests/integration/test_spanner_graph_store.py index 271dba30..36536fc0 100644 --- a/tests/integration/test_spanner_graph_store.py +++ b/tests/integration/test_spanner_graph_store.py @@ -163,43 +163,49 @@ def random_graph_doc(suffix): ) +@pytest.fixture +def setup_graph(request): + use_flexible_schema = request.getfixturevalue("use_flexible_schema") + suffix = random_string(num_char=5, exclude_whitespaces=True) + graph_name = "test_graph{}".format(suffix) + graph = SpannerGraphStore( + instance_id, + google_database, + graph_name, + client=Client(project=project_id), + use_flexible_schema=use_flexible_schema, + static_node_properties=["a", "b"], + static_edge_properties=["a", "b"], + ) + graph.refresh_schema() + + yield suffix, graph + + print("Clean up graph with name `{}`".format(graph.graph_name)) + graph.cleanup() + + class TestSpannerGraphStore: @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_random_doc(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - static_node_properties=random_property_names( - random_int(l=0, u=len(properties)) - ), - static_edge_properties=random_property_names( - random_int(l=0, u=len(properties)) - ), - ) - graph.refresh_schema() - - try: - node_ids = set() - edge_ids = set() - for _ in range(3): - graph_doc = random_graph_doc(suffix) - graph.add_graph_documents([graph_doc]) - node_ids.update({(n.type, n.id) for n in graph_doc.nodes}) - edge_ids.update( - { - (e.type, e.source.id, e.target.id) - for e in graph_doc.relationships - } - ) - graph.refresh_schema() + def test_spanner_graph_random_doc( + self, + setup_graph, + use_flexible_schema, + ): + suffix, graph = setup_graph + node_ids = set() + edge_ids = set() + for _ in range(3): + graph_doc = random_graph_doc(suffix) + graph.add_graph_documents([graph_doc]) + node_ids.update({(n.type, n.id) for n in graph_doc.nodes}) + edge_ids.update( + {(e.type, e.source.id, e.target.id) for e in graph_doc.relationships} + ) + graph.refresh_schema() - results = graph.query( - """ + results = graph.query( + """ GRAPH {} MATCH -> @@ -215,66 +221,42 @@ def test_spanner_graph_random_doc(self, use_flexible_schema): RETURN type, num_elements, @param AS param ORDER BY type """.format( - graph_name - ), - params={"param": random_param()}, - ) - assert len(results) == 2 - assert results[0]["type"] == "edge", "Mismatch type" - assert results[0]["num_elements"] == len( - edge_ids - ), "Mismatch number of edges" - assert results[1]["type"] == "node", "Mismatch type" - assert results[1]["num_elements"] == len( - node_ids - ), "Mismatch number of nodes" - - finally: - print("Clean up graph with name `{}`".format(graph_name)) - print(graph.get_schema) - print(graph.get_structured_schema) - print(graph.get_ddl()) - graph.cleanup() + graph.graph_name + ), + params={"param": random_param()}, + ) + assert len(results) == 2 + assert results[0]["type"] == "edge", "Mismatch type" + assert results[0]["num_elements"] == len(edge_ids), "Mismatch number of edges" + assert results[1]["type"] == "node", "Mismatch type" + assert results[1]["num_elements"] == len(node_ids), "Mismatch number of nodes" @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_doc_with_duplicate_elements(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - static_node_properties=random_property_names( - random_int(l=0, u=len(properties)) - ), - static_edge_properties=random_property_names( - random_int(l=0, u=len(properties)) + def test_spanner_graph_doc_with_duplicate_elements( + self, + setup_graph, + use_flexible_schema, + ): + suffix, graph = setup_graph + node0 = random_node("Node0{}".format(suffix)) + node1 = random_node("Node1{}".format(suffix)) + edge0 = random_edge("Edge01", node0, node1) + edge1 = random_edge("Edge01", node0, node1) + + doc = GraphDocument( + nodes=[node0, node1, node0, node1], + relationships=[edge0, edge1], + source=Document( + page_content="Hello, world!", + metadata={"source": "https://example.com"}, ), ) - graph.refresh_schema() + graph.add_graph_documents([doc]) - try: - node0 = random_node("Node0{}".format(suffix)) - node1 = random_node("Node1{}".format(suffix)) - edge0 = random_edge("Edge01", node0, node1) - edge1 = random_edge("Edge01", node0, node1) - - doc = GraphDocument( - nodes=[node0, node1, node0, node1], - relationships=[edge0, edge1], - source=Document( - page_content="Hello, world!", - metadata={"source": "https://example.com"}, - ), - ) - graph.add_graph_documents([doc]) - - # In the case of flexible schema, `properties` is a nested json - # field. - results = graph.query( - """ + # In the case of flexible schema, `properties` is a nested json + # field. + results = graph.query( + """ GRAPH {} MATCH -[e]-> @@ -282,72 +264,56 @@ def test_spanner_graph_doc_with_duplicate_elements(self, use_flexible_schema): RETURN COALESCE(properties.properties, JSON "{{}}") AS dynamic_properties, properties AS static_properties """.format( - graph_name - ), - params={"param": random_param()}, - ) - assert len(results) == 1 + graph.graph_name + ), + params={"param": random_param()}, + ) + assert len(results) == 1 - edge_properties = edge0.properties - edge_properties.update(edge1.properties) - missing_properties = set(edge_properties.keys()).difference( - set(results[0]["dynamic_properties"].keys()).union( - set(results[0]["static_properties"].keys()) - ) + edge_properties = edge0.properties + edge_properties.update(edge1.properties) + missing_properties = set(edge_properties.keys()).difference( + set(results[0]["dynamic_properties"].keys()).union( + set(results[0]["static_properties"].keys()) ) - print(edge0.properties) - print(edge1.properties) - print(results) - assert ( - len(missing_properties) == 0 - ), "Missing properties of edge: {}".format(missing_properties) - - finally: - print("Clean up graph with name `{}`".format(graph_name)) - graph.cleanup() + ) + assert len(missing_properties) == 0, "Missing properties of edge: {}".format( + missing_properties + ) @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_avoid_unnecessary_overwrite(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - static_node_properties=["a", "b"], - static_edge_properties=["a", "b"], + def test_spanner_graph_avoid_unnecessary_overwrite( + self, + setup_graph, + use_flexible_schema, + ): + suffix, graph = setup_graph + node0 = Node( + id=random_string(), + type="Node{}".format(suffix), + properties={"a": 1, "b": 1}, + ) + node1 = Node( + id=random_string(), + type="Node{}".format(suffix), + properties={"a": 1, "b": 1}, + ) + edge0 = Relationship( + source=node0, + target=node1, + type="Edge{}".format(suffix), + properties={"a": 1, "b": 1}, + ) + doc = GraphDocument( + nodes=[node0, node1], + relationships=[edge0], + source=Document( + page_content="Hello, world!", + metadata={"source": "https://example.com"}, + ), ) - graph.refresh_schema() - - try: - node0 = Node( - id=random_string(), - type="Node{}".format(suffix), - properties={"a": 1, "b": 1}, - ) - node1 = Node( - id=random_string(), - type="Node{}".format(suffix), - properties={"a": 1, "b": 1}, - ) - edge0 = Relationship( - source=node0, - target=node1, - type="Edge{}".format(suffix), - properties={"a": 1, "b": 1}, - ) - doc = GraphDocument( - nodes=[node0, node1], - relationships=[edge0], - source=Document( - page_content="Hello, world!", - metadata={"source": "https://example.com"}, - ), - ) - query = """GRAPH {} + query = """GRAPH {} MATCH (n {{id: @nodeId}}) LET properties = TO_JSON(n)['properties'] RETURN int64(properties.a) AS a, int64(properties.b) AS b @@ -356,50 +322,35 @@ def test_spanner_graph_avoid_unnecessary_overwrite(self, use_flexible_schema): LET properties = TO_JSON(e)['properties'] RETURN int64(properties.a) AS a, int64(properties.b) AS b """.format( - graph_name - ) - graph.add_graph_documents([doc]) - - # Test initial value: a=1, b=1 - results = graph.query(query, {"nodeId": node0.id}) - assert len(results) == 2, "Actual results: {}".format(results) - assert all((r["a"] == 1 for r in results)), "Actual results: {}".format( - results - ) - assert all((r["b"] == 1 for r in results)), "Actual results: {}".format( - results - ) - - node0.properties["a"] = 2 - edge0.properties["a"] = 2 - graph.add_graph_documents([doc]) - - # Test value after first overwrite: a=2, b=1 - results = graph.query(query, {"nodeId": node0.id}) - assert len(results) == 2, "Actual results: {}".format(results) - assert all((r["a"] == 2 for r in results)), "Actual results: {}".format( - results - ) - assert all((r["b"] == 1 for r in results)), "Actual results: {}".format( - results - ) - - node0.properties = {} - edge0.properties = {} - graph.add_graph_documents([doc]) - - # Test value after second overwrite: a=2, b=1 - results = graph.query(query, {"nodeId": node0.id}) - assert len(results) == 2, "Actual results: {}".format(results) - assert all((r["a"] == 2 for r in results)), "Actual results: {}".format( - results - ) - assert all((r["b"] == 1 for r in results)), "Actual results: {}".format( - results - ) - finally: - print("Clean up graph with name `{}`".format(graph_name)) - graph.cleanup() + graph.graph_name + ) + graph.add_graph_documents([doc]) + + # Test initial value: a=1, b=1 + results = graph.query(query, {"nodeId": node0.id}) + assert len(results) == 2, "Actual results: {}".format(results) + assert all((r["a"] == 1 for r in results)), "Actual results: {}".format(results) + assert all((r["b"] == 1 for r in results)), "Actual results: {}".format(results) + + node0.properties["a"] = 2 + edge0.properties["a"] = 2 + graph.add_graph_documents([doc]) + + # Test value after first overwrite: a=2, b=1 + results = graph.query(query, {"nodeId": node0.id}) + assert len(results) == 2, "Actual results: {}".format(results) + assert all((r["a"] == 2 for r in results)), "Actual results: {}".format(results) + assert all((r["b"] == 1 for r in results)), "Actual results: {}".format(results) + + node0.properties = {} + edge0.properties = {} + graph.add_graph_documents([doc]) + + # Test value after second overwrite: a=2, b=1 + results = graph.query(query, {"nodeId": node0.id}) + assert len(results) == 2, "Actual results: {}".format(results) + assert all((r["a"] == 2 for r in results)), "Actual results: {}".format(results) + assert all((r["b"] == 1 for r in results)), "Actual results: {}".format(results) @pytest.mark.parametrize( "graph_name, raises_exception", @@ -435,36 +386,31 @@ def test_spanner_graph_invalid_graph_name(self, graph_name, raises_exception): ) @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_with_existing_graph(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) + def test_spanner_graph_with_existing_graph( + self, + setup_graph, + use_flexible_schema, + ): + suffix, graph = setup_graph + graph_name = graph.graph_name node_table_name = "{}_node".format(graph_name) edge_table_name = "{}_edge".format(graph_name) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - ) - graph.refresh_schema() - try: - graph.impl.apply_ddls( - [ - f""" + graph.impl.apply_ddls( + [ + f""" CREATE TABLE IF NOT EXISTS {node_table_name} ( id INT64 NOT NULL, str STRING(MAX), token TOKENLIST AS (TOKENIZE_FULLTEXT(str)) HIDDEN, ) PRIMARY KEY (id) """, - f""" + f""" CREATE TABLE IF NOT EXISTS {edge_table_name} ( id INT64 NOT NULL, target_id INT64 NOT NULL, ) PRIMARY KEY (id, target_id) """, - f""" + f""" CREATE PROPERTY GRAPH IF NOT EXISTS {graph_name} NODE TABLES ( {node_table_name} AS NodeA @@ -487,39 +433,124 @@ def test_spanner_graph_with_existing_graph(self, use_flexible_schema): LABEL EdgeBA PROPERTIES(target_id AS node_a_id, id AS node_b_id), ) """, - ] - ) - graph.refresh_schema() - schema = json.loads(graph.get_schema) - edgeab = graph.schema.get_edge_schema("EdgeAB") - edgeba = graph.schema.get_edge_schema("EdgeBA") - assert (edgeab.source.node_name, edgeab.target.node_name) == ( - "NodeA", - "NodeB", - ) - assert (edgeba.source.node_name, edgeba.target.node_name) == ( - "NodeB", - "NodeA", - ) - # TOKENLIST-typed properties are ignored. - assert len(schema["Node properties per node label"]["Node"]) == 4, schema[ - "Node properties per node label" - ]["Node"] - assert len(schema["Node properties per node label"]["NodeA"]) == 3, schema[ - "Node properties per node label" - ]["NodeA"] - assert len(schema["Node properties per node label"]["NodeB"]) == 3, schema[ - "Node properties per node label" - ]["NodB"] - assert len(schema["Possible edges per label"]["EdgeAB"]) == 4, schema[ - "Possible edges per label" - ]["EdgeAB"] - assert len(schema["Possible edges per label"]["EdgeBA"]) == 4, schema[ - "Possible edges per label" - ]["EdgeBA"] - assert len(schema["Possible edges per label"]["Edge"]) == 8, schema[ - "Possible edges per label" - ]["Edge"] - finally: - print("Clean up graph with name `{}`".format(graph_name)) - graph.cleanup() + ] + ) + graph.refresh_schema() + schema = json.loads(graph.get_schema) + edgeab = graph.schema.get_edge_schema("EdgeAB") + edgeba = graph.schema.get_edge_schema("EdgeBA") + assert (edgeab.source.node_name, edgeab.target.node_name) == ( + "NodeA", + "NodeB", + ) + assert (edgeba.source.node_name, edgeba.target.node_name) == ( + "NodeB", + "NodeA", + ) + # TOKENLIST-typed properties are ignored. + assert schema["Node properties per node label"]["Node"] == [ + {"name": "id", "type": "INT64"}, + {"name": "node_b_id", "type": "INT64"}, + {"name": "str", "type": "STRING"}, + ], "Invalid Node properties" + assert schema["Node properties per node label"]["NodeA"] == [ + {"name": "id", "type": "INT64"}, + {"name": "node_a_id", "type": "INT64"}, + {"name": "str", "type": "STRING"}, + ], "Invalid NodeA properties" + assert schema["Node properties per node label"]["NodeB"] == [ + {"name": "id", "type": "INT64"}, + {"name": "node_b_id", "type": "INT64"}, + {"name": "str", "type": "STRING"}, + ], "Invalid NodeB properties" + assert schema["Possible edges per label"]["EdgeAB"] == [ + "(:Node) -[:EdgeAB]-> (:Node)", + "(:Node) -[:EdgeAB]-> (:NodeB)", + "(:NodeA) -[:EdgeAB]-> (:Node)", + "(:NodeA) -[:EdgeAB]-> (:NodeB)", + ], "Invalid EdgeAB patterns" + assert schema["Possible edges per label"]["EdgeBA"] == [ + "(:Node) -[:EdgeBA]-> (:Node)", + "(:Node) -[:EdgeBA]-> (:NodeA)", + "(:NodeB) -[:EdgeBA]-> (:Node)", + "(:NodeB) -[:EdgeBA]-> (:NodeA)", + ], "Invalid EdgeBA patterns" + assert schema["Possible edges per label"]["Edge"] == [ + "(:Node) -[:Edge]-> (:Node)", + "(:Node) -[:Edge]-> (:NodeA)", + "(:Node) -[:Edge]-> (:NodeB)", + "(:NodeA) -[:Edge]-> (:Node)", + "(:NodeA) -[:Edge]-> (:NodeB)", + "(:NodeB) -[:Edge]-> (:Node)", + "(:NodeB) -[:Edge]-> (:NodeA)", + ], "Invalid Edge patterns" + + @pytest.mark.parametrize("use_flexible_schema", [False, True]) + def test_spanner_graph_schema_representation( + self, + setup_graph, + use_flexible_schema, + ): + suffix, graph = setup_graph + node0 = Node( + id=random_string(), + type="Node0{}".format(suffix), + properties={"j0": random_int()}, + ) + node1 = Node( + id=random_string(), + type="Node1{}".format(suffix), + properties={"j1": random_string()}, + ) + edge = Relationship( + source=node0, target=node1, type="Links", properties={"j": random_json()} + ) + + doc = GraphDocument( + nodes=[node0, node1], + relationships=[edge], + source=Document( + page_content="Hello, world!", + metadata={"source": "https://example.com"}, + ), + ) + graph.add_graph_documents([doc]) + schema = json.loads(graph.get_schema) + node0_json_fields = sorted( + [p["name"] for p in schema["Node properties per node label"][node0.type]] + ) + node1_json_fields = sorted( + [p["name"] for p in schema["Node properties per node label"][node1.type]] + ) + edge_json_fields = sorted( + [ + p["name"] + for edge in schema["Edge properties per edge label"].values() + for p in edge + ] + ) + edge_patterns = sorted( + [ + pattern + for edge in schema["Possible edges per label"].values() + for pattern in edge + ] + ) + if use_flexible_schema: + assert node0_json_fields == ["id", "j0", "label", "properties"] + assert node1_json_fields == ["id", "j1", "label", "properties"] + assert edge_json_fields == ["id", "j", "label", "properties", "target_id"] + assert edge_patterns == [ + "(:{src}) -[:{edge}]-> (:{dst})".format( + src=node0.type, edge=edge.type, dst=node1.type + ) + ] + else: + assert node0_json_fields == ["id", "j0"] + assert node1_json_fields == ["id", "j1"] + assert edge_json_fields == ["id", "j", "target_id"] + assert edge_patterns == [ + "(:{src}) -[:{src}_{edge}_{dst}]-> (:{dst})".format( + src=node0.type, edge=edge.type, dst=node1.type + ) + ]