Skip to content

Commit b258f4b

Browse files
committed
fix: lead relation generate hypos
1 parent f760d6f commit b258f4b

File tree

1 file changed

+43
-11
lines changed

1 file changed

+43
-11
lines changed

traincheck/invariant/lead_relation.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,22 @@ def generate_hypothesis(trace) -> list[Hypothesis]:
384384
event_A_idx = 0
385385
event_B_idx = 0
386386

387+
pre_event_A_idx = None
388+
pre_event_A_time = None
389+
387390
for event_A_pre in events_A_pre:
391+
invocation_id = event_A_pre["func_call_id"]
392+
event_A_post = trace.get_post_func_call_record(invocation_id)
393+
assert event_A_post is not None, "Post event not found"
394+
pre_event_A_idx = event_A_idx
395+
pre_event_A_time = event_A_post["time"]
396+
event_A_idx += 1
397+
break
398+
399+
assert pre_event_A_idx is not None
400+
assert pre_event_A_time is not None
401+
402+
for event_A_pre in events_A_pre[event_A_idx:]:
388403
invocation_id = event_A_pre["func_call_id"]
389404
example = Example()
390405
example.add_group(EXP_GROUP_NAME, [event_A_pre])
@@ -397,26 +412,41 @@ def generate_hypothesis(trace) -> list[Hypothesis]:
397412
assert event_A_post is not None, "Post event not found"
398413

399414
found_B_after_A = False
400-
while (
401-
event_B_idx < len(events_B_pre)
402-
and events_B_pre[event_B_idx]["time"] < event_A_post["time"]
403-
):
404-
event_B_idx += (
405-
1 # Skip B events that occurred before the current A event
406-
)
415+
# First A post time <= B pre time <= B post time <= next A pre time
416+
while event_B_idx < len(events_B_pre):
417+
event_B_pre = events_B_pre[event_B_idx]
418+
event_B_time = event_B_pre["time"]
419+
420+
if event_B_time > event_A_pre["time"]:
421+
break
422+
423+
if event_B_time <= pre_event_A_time:
424+
event_B_idx += 1
425+
continue
426+
427+
B_invocation_id = event_B_pre["func_call_id"]
428+
event_B_post = trace.get_post_func_call_record(B_invocation_id)
429+
assert event_B_post is not None, "Post event not found"
430+
if event_B_post["time"] > event_A_pre["time"]:
431+
event_B_idx += 1
432+
continue
407433

408-
if event_B_idx < len(events_B_pre):
409-
# Check if there's a B event after the current A event
410434
found_B_after_A = True
435+
event_B_idx += 1
436+
break
437+
438+
if found_B_after_A:
439+
# Check if there's a B event after the current A event
411440
hypothesis_with_examples[
412441
(func_A, func_B)
413442
].positive_examples.add_example(example)
414-
415-
if not found_B_after_A:
443+
else:
416444
hypothesis_with_examples[
417445
(func_A, func_B)
418446
].negative_examples.add_example(example)
419447

448+
pre_event_A_idx = event_A_idx
449+
pre_event_A_time = event_A_post["time"]
420450
event_A_idx += 1
421451
# add the rest of the A events as negative examples
422452
for event_A_pre in events_A_pre[event_A_idx:]:
@@ -430,6 +460,7 @@ def generate_hypothesis(trace) -> list[Hypothesis]:
430460

431461
return list(hypothesis_with_examples.values())
432462

463+
# TODO: fix
433464
@staticmethod
434465
def collect_examples(trace, hypothesis):
435466
"""Generate examples for a hypothesis on trace."""
@@ -691,6 +722,7 @@ def evaluate(value_group: list) -> bool:
691722
"""
692723
return True
693724

725+
# TODO: fix
694726
@staticmethod
695727
def static_check_all(
696728
trace: Trace, inv: Invariant, check_relation_first: bool

0 commit comments

Comments
 (0)