@@ -204,7 +204,7 @@ def __init__(
204204
205205 self .__set_db (db )
206206 if all ([self .__db , name ]):
207- self .__set_graph (name , default_node_type , edge_type_func )
207+ self .__set_graph (name , overwrite_graph , default_node_type , edge_type_func )
208208 self .__set_edge_collections_attributes (edge_collections_attributes )
209209
210210 # NOTE: Need to revisit these...
@@ -232,23 +232,6 @@ def __init__(
232232 self ._set_factory_methods (read_parallelism , read_batch_size )
233233 self .__set_arangodb_backend_config ()
234234
235- if overwrite_graph :
236- logger .info ("Overwriting graph..." )
237-
238- properties = self .adb_graph .properties ()
239- self .db .delete_graph (name , drop_collections = True )
240- self .db .create_graph (
241- name = name ,
242- edge_definitions = properties ["edge_definitions" ],
243- orphan_collections = properties ["orphan_collections" ],
244- smart = properties .get ("smart" ),
245- disjoint = properties .get ("disjoint" ),
246- smart_field = properties .get ("smart_field" ),
247- shard_count = properties .get ("shard_count" ),
248- replication_factor = properties .get ("replication_factor" ),
249- write_concern = properties .get ("write_concern" ),
250- )
251-
252235 if isinstance (incoming_graph_data , nx .Graph ):
253236 self ._load_nx_graph (incoming_graph_data , write_batch_size , write_async )
254237 self ._loaded_incoming_graph_data = True
@@ -367,13 +350,33 @@ def __set_db(self, db: Any = None) -> None:
367350 def __set_graph (
368351 self ,
369352 name : Any ,
353+ overwrite_graph : bool ,
370354 default_node_type : str | None = None ,
371355 edge_type_func : Callable [[str , str ], str ] | None = None ,
372356 ) -> None :
373357 if not isinstance (name , str ):
374358 raise TypeError ("**name** must be a string" )
375359
376- if self .db .has_graph (name ):
360+ graph_exists = self .db .has_graph (name )
361+
362+ if graph_exists and overwrite_graph :
363+ logger .info (f"Overwriting graph '{ name } '" )
364+
365+ properties = self .db .graph (name ).properties ()
366+ self .db .delete_graph (name , drop_collections = True )
367+ self .db .create_graph (
368+ name = name ,
369+ edge_definitions = properties ["edge_definitions" ],
370+ orphan_collections = properties ["orphan_collections" ],
371+ smart = properties .get ("smart" ),
372+ disjoint = properties .get ("disjoint" ),
373+ smart_field = properties .get ("smart_field" ),
374+ shard_count = properties .get ("shard_count" ),
375+ replication_factor = properties .get ("replication_factor" ),
376+ write_concern = properties .get ("write_concern" ),
377+ )
378+
379+ if graph_exists :
377380 logger .info (f"Graph '{ name } ' exists." )
378381
379382 if edge_type_func is not None :
@@ -613,9 +616,14 @@ def chat(
613616 if llm is None :
614617 llm = ChatOpenAI (temperature = 0 , model_name = "gpt-4" )
615618
619+ graph = ArangoGraph (
620+ self .db ,
621+ # graph_name=self.name # not yet supported
622+ )
623+
616624 chain = ArangoGraphQAChain .from_llm (
617625 llm = llm ,
618- graph = ArangoGraph ( self . db ) ,
626+ graph = graph ,
619627 verbose = verbose ,
620628 )
621629
0 commit comments