Skip to content

Commit 0738645

Browse files
fix(graphgen): move loss_strategy to config
1 parent bb6e0d1 commit 0738645

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

graphgen/operators/split_graph.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55

66
from models import NetworkXStorage, TraverseStrategy
77

8-
# TODO: move to config
9-
loss_strategy: str = "only_edge" # only_edge, both
10-
118
async def _get_node_info(
129
node_id: str,
1310
graph_storage: NetworkXStorage,
@@ -35,7 +32,8 @@ def _get_level_n_edges_by_max_width(
3532
max_depth: int,
3633
bidirectional: bool,
3734
max_extra_edges: int,
38-
edge_sampling: str
35+
edge_sampling: str,
36+
loss_strategy: str = "only_edge"
3937
) -> list:
4038
"""
4139
Get level n edges for an edge.
@@ -111,7 +109,8 @@ def _get_level_n_edges_by_max_tokens(
111109
max_depth: int,
112110
bidirectional: bool,
113111
max_tokens: int,
114-
edge_sampling: str
112+
edge_sampling: str,
113+
loss_strategy: str = "only_edge"
115114
) -> list:
116115
"""
117116
Get level n edges for an edge.
@@ -256,13 +255,13 @@ async def get_cached_node_info(node_id: str) -> dict:
256255
for i, (node_name, _) in enumerate(nodes):
257256
node_dict[node_name] = i
258257

259-
if loss_strategy == "both":
258+
if traverse_strategy.loss_strategy == "both":
260259
er_tuples = [([nodes[node_dict[edge[0]]], nodes[node_dict[edge[1]]]], edge) for edge in edges]
261260
edges = _sort_tuples(er_tuples, edge_sampling)
262-
elif loss_strategy == "only_edge":
261+
elif traverse_strategy.loss_strategy == "only_edge":
263262
edges = _sort_edges(edges, edge_sampling)
264263
else:
265-
raise ValueError(f"Invalid loss strategy: {loss_strategy}")
264+
raise ValueError(f"Invalid loss strategy: {traverse_strategy.loss_strategy}")
266265

267266
for i, (src, tgt, _) in enumerate(edges):
268267
edge_adj_list[src].append(i)
@@ -288,13 +287,13 @@ async def get_cached_node_info(node_id: str) -> dict:
288287
level_n_edges = _get_level_n_edges_by_max_width(
289288
edge_adj_list, node_dict, edges, nodes, edge, max_depth,
290289
traverse_strategy.bidirectional, traverse_strategy.max_extra_edges,
291-
edge_sampling
290+
edge_sampling, traverse_strategy.loss_strategy
292291
)
293292
else:
294293
level_n_edges = _get_level_n_edges_by_max_tokens(
295294
edge_adj_list, node_dict, edges, nodes, edge, max_depth,
296295
traverse_strategy.bidirectional, traverse_strategy.max_tokens,
297-
edge_sampling
296+
edge_sampling, traverse_strategy.loss_strategy
298297
)
299298

300299
for _edge in level_n_edges:

0 commit comments

Comments
 (0)