diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index ed7cbda..bdcb5a2 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -204,7 +204,7 @@ def __init__( self.__set_db(db) if all([self.__db, name]): - self.__set_graph(name, default_node_type, edge_type_func) + self.__set_graph(name, overwrite_graph, default_node_type, edge_type_func) self.__set_edge_collections_attributes(edge_collections_attributes) # NOTE: Need to revisit these... @@ -232,23 +232,6 @@ def __init__( self._set_factory_methods(read_parallelism, read_batch_size) self.__set_arangodb_backend_config() - if overwrite_graph: - logger.info("Overwriting graph...") - - properties = self.adb_graph.properties() - self.db.delete_graph(name, drop_collections=True) - self.db.create_graph( - name=name, - edge_definitions=properties["edge_definitions"], - orphan_collections=properties["orphan_collections"], - smart=properties.get("smart"), - disjoint=properties.get("disjoint"), - smart_field=properties.get("smart_field"), - shard_count=properties.get("shard_count"), - replication_factor=properties.get("replication_factor"), - write_concern=properties.get("write_concern"), - ) - if isinstance(incoming_graph_data, nx.Graph): self._load_nx_graph(incoming_graph_data, write_batch_size, write_async) self._loaded_incoming_graph_data = True @@ -367,13 +350,33 @@ def __set_db(self, db: Any = None) -> None: def __set_graph( self, name: Any, + overwrite_graph: bool, default_node_type: str | None = None, edge_type_func: Callable[[str, str], str] | None = None, ) -> None: if not isinstance(name, str): raise TypeError("**name** must be a string") - if self.db.has_graph(name): + graph_exists = self.db.has_graph(name) + + if graph_exists and overwrite_graph: + logger.info(f"Overwriting graph '{name}'") + + properties = self.db.graph(name).properties() + self.db.delete_graph(name, drop_collections=True) + self.db.create_graph( + name=name, + edge_definitions=properties["edge_definitions"], + orphan_collections=properties["orphan_collections"], + smart=properties.get("smart"), + disjoint=properties.get("disjoint"), + smart_field=properties.get("smart_field"), + shard_count=properties.get("shard_count"), + replication_factor=properties.get("replication_factor"), + write_concern=properties.get("write_concern"), + ) + + if graph_exists: logger.info(f"Graph '{name}' exists.") if edge_type_func is not None: @@ -613,9 +616,14 @@ def chat( if llm is None: llm = ChatOpenAI(temperature=0, model_name="gpt-4") + graph = ArangoGraph( + self.db, + # graph_name=self.name # not yet supported + ) + chain = ArangoGraphQAChain.from_llm( llm=llm, - graph=ArangoGraph(self.db), + graph=graph, verbose=verbose, )