55
66from models import NetworkXStorage , TraverseStrategy
77
8- # TODO: move to config
9- loss_strategy : str = "only_edge" # only_edge, both
10-
118async def _get_node_info (
129 node_id : str ,
1310 graph_storage : NetworkXStorage ,
@@ -35,7 +32,8 @@ def _get_level_n_edges_by_max_width(
3532 max_depth : int ,
3633 bidirectional : bool ,
3734 max_extra_edges : int ,
38- edge_sampling : str
35+ edge_sampling : str ,
36+ loss_strategy : str = "only_edge"
3937) -> list :
4038 """
4139 Get level n edges for an edge.
@@ -111,7 +109,8 @@ def _get_level_n_edges_by_max_tokens(
111109 max_depth : int ,
112110 bidirectional : bool ,
113111 max_tokens : int ,
114- edge_sampling : str
112+ edge_sampling : str ,
113+ loss_strategy : str = "only_edge"
115114) -> list :
116115 """
117116 Get level n edges for an edge.
@@ -256,13 +255,13 @@ async def get_cached_node_info(node_id: str) -> dict:
256255 for i , (node_name , _ ) in enumerate (nodes ):
257256 node_dict [node_name ] = i
258257
259- if loss_strategy == "both" :
258+ if traverse_strategy . loss_strategy == "both" :
260259 er_tuples = [([nodes [node_dict [edge [0 ]]], nodes [node_dict [edge [1 ]]]], edge ) for edge in edges ]
261260 edges = _sort_tuples (er_tuples , edge_sampling )
262- elif loss_strategy == "only_edge" :
261+ elif traverse_strategy . loss_strategy == "only_edge" :
263262 edges = _sort_edges (edges , edge_sampling )
264263 else :
265- raise ValueError (f"Invalid loss strategy: { loss_strategy } " )
264+ raise ValueError (f"Invalid loss strategy: { traverse_strategy . loss_strategy } " )
266265
267266 for i , (src , tgt , _ ) in enumerate (edges ):
268267 edge_adj_list [src ].append (i )
@@ -288,13 +287,13 @@ async def get_cached_node_info(node_id: str) -> dict:
288287 level_n_edges = _get_level_n_edges_by_max_width (
289288 edge_adj_list , node_dict , edges , nodes , edge , max_depth ,
290289 traverse_strategy .bidirectional , traverse_strategy .max_extra_edges ,
291- edge_sampling
290+ edge_sampling , traverse_strategy . loss_strategy
292291 )
293292 else :
294293 level_n_edges = _get_level_n_edges_by_max_tokens (
295294 edge_adj_list , node_dict , edges , nodes , edge , max_depth ,
296295 traverse_strategy .bidirectional , traverse_strategy .max_tokens ,
297- edge_sampling
296+ edge_sampling , traverse_strategy . loss_strategy
298297 )
299298
300299 for _edge in level_n_edges :
0 commit comments