Skip to content

Commit bb6e0d1

Browse files
refactor(configs): add loss_strategy in traverse_strategy
1 parent d14c5dc commit bb6e0d1

File tree

3 files changed

+15
-13
lines changed

3 files changed

+15
-13
lines changed

configs/config.yaml.example

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,4 +15,5 @@ traverse_strategy:
1515
max_depth: 2
1616
max_extra_edges: 5
1717
max_tokens: 256
18+
loss_strategy: only_edge
1819
web_search: false

configs/graphgen_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,4 +15,5 @@ traverse_strategy:
1515
max_depth: 1
1616
max_extra_edges: 2
1717
max_tokens: 256
18+
loss_strategy: only_edge
1819
web_search: false

graphgen/operators/traverse_graph.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,6 @@
77
from utils import detect_main_language, compute_content_hash, logger
88
from graphgen.operators.split_graph import get_batches_with_strategy
99

10-
# TODO: move to config
11-
# TODO: if add isolated nodes, the loss strategy should be changed to "both"
12-
loss_strategy: str = "only_edge" # only_edge, both
1310

1411
async def _pre_tokenize(graph_storage: NetworkXStorage,
1512
tokenizer: Tokenizer,
@@ -101,22 +98,23 @@ def get_loss_tercile(losses: list) -> (float, float):
10198

10299
return losses[q1_index], losses[q2_index]
103100

104-
def assign_difficulty(subgraphs: list, difficulty_order: list) -> list:
101+
def assign_difficulty(subgraphs: list, difficulty_order: list, loss_strategy: str) -> list:
105102
"""
106-
Assign difficulty to subgraphs based on the loss
103+
Assign difficulty to subgraphs based on the loss.
107104
108105
:param subgraphs
109106
:param difficulty_order
107+
:param loss_strategy
110108
:return
111109
"""
112110
losses = []
113111
for subgraph in subgraphs:
114-
loss = get_average_loss(subgraph)
112+
loss = get_average_loss(subgraph, loss_strategy)
115113
losses.append(loss)
116114
q1, q2 = get_loss_tercile(losses)
117115

118116
for i, subgraph in enumerate(subgraphs):
119-
loss = get_average_loss(subgraph)
117+
loss = get_average_loss(subgraph, loss_strategy)
120118
if loss < q1:
121119
# easy
122120
subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[0])
@@ -128,7 +126,7 @@ def assign_difficulty(subgraphs: list, difficulty_order: list) -> list:
128126
subgraphs[i] = (subgraph[0], subgraph[1], difficulty_order[2])
129127
return subgraphs
130128

131-
def get_average_loss(batch: tuple) -> float:
129+
def get_average_loss(batch: tuple, loss_strategy: str) -> float:
132130
if loss_strategy == "only_edge":
133131
return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
134132
if loss_strategy == "both":
@@ -242,7 +240,7 @@ async def _process_single_batch(
242240
compute_content_hash(context): {
243241
"question": question,
244242
"answer": context,
245-
"loss": get_average_loss(_process_batch),
243+
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
246244
"difficulty": _process_batch[2],
247245
}
248246
}
@@ -268,7 +266,7 @@ async def _process_single_batch(
268266
final_results[compute_content_hash(qa['question'])] = {
269267
"question": qa['question'],
270268
"answer": qa['answer'],
271-
"loss": get_average_loss(_process_batch),
269+
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
272270
"difficulty": _process_batch[2],
273271
}
274272
return final_results
@@ -286,7 +284,8 @@ async def _process_single_batch(
286284
traverse_strategy
287285
)
288286

289-
processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order)
287+
processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order,
288+
traverse_strategy.loss_strategy)
290289

291290
for result in tqdm_async(asyncio.as_completed(
292291
[_process_single_batch(batch) for batch in processing_batches]
@@ -436,7 +435,8 @@ async def traverse_graph_for_multi_hop(
436435
traverse_strategy
437436
)
438437

439-
processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order)
438+
processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order,
439+
traverse_strategy.loss_strategy)
440440

441441
async def _process_single_batch(
442442
_process_batch: tuple
@@ -487,7 +487,7 @@ async def _process_single_batch(
487487
compute_content_hash(question): {
488488
"question": question,
489489
"answer": answer,
490-
"loss": get_average_loss(_process_batch),
490+
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
491491
"difficulty": _process_batch[2],
492492
}
493493
}

0 commit comments

Comments
 (0)