Skip to content

Commit 31cd48f

Browse files
committed
wip: narrow deep paths now creates a chain of var edges (test included)
1 parent e2edcd0 commit 31cd48f

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

gp_learner.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -413,19 +413,49 @@ def mutate_del_triple(child):
413413
return new_child
414414

415415

416-
def mutate_expand_node(child, pb_en_out_link):
417-
# TODO: can maybe be improved by sparqling
418-
nodes = list(child.nodes)
419-
node = random.choice(nodes)
416+
def _mutate_expand_node_helper(node, pb_en_out_link=config.MUTPB_EN_OUT_LINK):
420417
var_edge = gen_random_var()
421418
var_node = gen_random_var()
422419
if random.random() < pb_en_out_link:
423420
new_triple = (node, var_edge, var_node)
424421
else:
425422
new_triple = (var_node, var_edge, node)
423+
return new_triple, var_node
424+
425+
426+
def mutate_expand_node(child, node=None):
427+
# TODO: can maybe be improved by sparqling
428+
if not node:
429+
nodes = list(child.nodes)
430+
node = random.choice(nodes)
431+
new_triple, _ = _mutate_expand_node_helper(node)
426432
return child + (new_triple,)
427433

428434

435+
def mutate_deep_narrow_path(
436+
child,
437+
min_len=config.MUTPB_DN_MIN_LEN,
438+
max_len=config.MUTPB_DN_MAX_LEN,
439+
term_pb=config.MUTPB_DN_TERM_PB,
440+
):
441+
assert isinstance(child, GraphPattern)
442+
nodes = list(child.nodes)
443+
start_node = random.choice(nodes)
444+
# target_nodes = set(nodes) - {start_node}
445+
gp = child
446+
hop = 0
447+
while True:
448+
if hop >= min_len and random.random() < term_pb:
449+
break
450+
if hop >= max_len:
451+
break
452+
hop += 1
453+
new_triple, var_node = _mutate_expand_node_helper(start_node)
454+
gp += [new_triple]
455+
start_node = var_node
456+
return gp
457+
458+
429459
def mutate_add_edge(child):
430460
# TODO: can maybe be improved by sparqling
431461
nodes = list(child.nodes)
@@ -647,7 +677,6 @@ def mutate(
647677
pb_ae=config.MUTPB_AE,
648678
pb_dt=config.MUTPB_DT,
649679
pb_en=config.MUTPB_EN,
650-
pb_en_out_link=config.MUTPB_EN_OUT_LINK,
651680
pb_fv=config.MUTPB_FV,
652681
pb_id=config.MUTPB_ID,
653682
pb_iv=config.MUTPB_IV,
@@ -678,7 +707,7 @@ def mutate(
678707
child = mutate_del_triple(child)
679708

680709
if random.random() < pb_en:
681-
child = mutate_expand_node(child, pb_en_out_link)
710+
child = mutate_expand_node(child)
682711
if random.random() < pb_ae:
683712
child = mutate_add_edge(child)
684713

@@ -694,7 +723,6 @@ def mutate(
694723
else:
695724
children = [child]
696725

697-
698726
# TODO: deep & narrow paths mutation
699727

700728
children = {

tests/test_gp_learner_offline.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from gp_learner import mutate_increase_dist
1414
from gp_learner import mutate_merge_var
1515
from gp_learner import mutate_simplify_pattern
16+
from gp_learner import mutate_deep_narrow_path
1617
from graph_pattern import GraphPattern
1718
from graph_pattern import SOURCE_VAR
1819
from graph_pattern import TARGET_VAR
@@ -108,6 +109,17 @@ def test_mutate_merge_var():
108109
assert False, "merge never reached one of the cases: %s" % cases
109110

110111

112+
def test_mutate_deep_narrow_path():
113+
p = Variable('p')
114+
gp = GraphPattern([
115+
(SOURCE_VAR, p, TARGET_VAR)
116+
])
117+
child = mutate_deep_narrow_path(gp)
118+
assert gp == child or len(child) > len(gp)
119+
print(gp)
120+
print(child)
121+
122+
111123
def test_simplify_pattern():
112124
gp = GraphPattern([(SOURCE_VAR, wikilink, TARGET_VAR)])
113125
res = mutate_simplify_pattern(gp)
@@ -271,3 +283,7 @@ def test_remaining_gain_sample_gtps():
271283

272284
def test_gtp_scores():
273285
assert gtp_scores - gtp_scores == 0
286+
287+
288+
if __name__ == '__main__':
289+
test_mutate_deep_narrow_path()

0 commit comments

Comments
 (0)