Skip to content

Commit 94f9ee2

Browse files
committed
mutate_deep_narrow_path() is expanding the graph pattern, and fixing suitable edges
1 parent b8e52bf commit 94f9ee2

File tree

3 files changed

+116
-100
lines changed

3 files changed

+116
-100
lines changed

gp_learner.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def mutate_expand_node(child, node=None):
429429
if not node:
430430
nodes = list(child.nodes)
431431
node = random.choice(nodes)
432-
new_triple, _ = _mutate_expand_node_helper(node)
432+
new_triple, _, _ = _mutate_expand_node_helper(node)
433433
return child + (new_triple,)
434434

435435

@@ -452,15 +452,10 @@ def mutate_deep_narrow_path(
452452
break
453453
hop += 1
454454
new_triple, var_node, var_edge = _mutate_expand_node_helper(start_node)
455-
test_gp = gp + [new_triple]
456-
test_gp, fixed = _mutate_deep_narrow_path_helper(
457-
sparql, timeout, gtp_scores, test_gp, var_edge, var_node)
458-
if fixed:
459-
start_node = var_node
460-
gp = test_gp
461-
462-
# TODO: insert connection to a target node
463-
# TODO: fix edge or node ( to_count_var_over_values_query)
455+
gp += [new_triple]
456+
gp, fixed = _mutate_deep_narrow_path_helper(
457+
sparql, timeout, gtp_scores, gp, var_edge, var_node)
458+
start_node = var_node
464459
return gp
465460

466461

@@ -489,36 +484,52 @@ def _mutate_deep_narrow_path_helper(
489484
t, substitution_counts = variable_substitution_deep_narrow_mut_query(
490485
sparql, timeout, child, edge_var, node_var, ground_truth_pairs,
491486
limit_res)
492-
if not substitution_counts:
487+
edge_count, node_sum_count = substitution_counts
488+
if not node_sum_count:
493489
# the current pattern is unfit, as we can't find anything fulfilling it
494490
logger.debug("tried to fix a var %s without result:\n%s"
495-
"seems as if the pattern can't be fulfilled!",
496-
edge_var, child.to_sparql_select_query())
491+
"seems as if the pattern can't be fulfilled!",
492+
edge_var, child.to_sparql_select_query())
497493
fixed = False
498-
return [child], fixed
499-
mutate_fix_var_filter(substitution_counts)
500-
if not substitution_counts:
494+
return child, fixed
495+
mutate_fix_var_filter(node_sum_count)
496+
mutate_fix_var_filter(edge_count)
497+
if not node_sum_count:
501498
# could have happened that we removed the only possible substitution
502499
fixed = False
503-
return [child], fixed
500+
return child, fixed
501+
502+
prio = Counter()
503+
for edge, node_sum in node_sum_count.items():
504+
ec = edge_count[edge]
505+
prio[edge] = ec / (node_sum / ec) # ec / AVG degree
504506
# randomly pick n of the substitutions with a prob ~ to their counts
505-
items, counts = zip(*substitution_counts.most_common())
506-
substs = sample_from_list(items, counts, sample_n)
507+
edges, prios = zip(*prio.most_common())
508+
509+
substs = sample_from_list(edges, prios, sample_n)
510+
507511
logger.info(
508-
'fixed variable %s in %sto:\n %s\n<%d out of:\n%s\n',
509-
edge_var.n3(),
510-
child,
511-
'\n '.join([subst.n3() for subst in substs]),
512-
sample_n,
513-
'\n'.join([' %d: %s' % (c, v.n3())
514-
for v, c in substitution_counts.most_common()]),
512+
'fixed variable %s in %sto:\n %s\n<%d out of:\n%s\n',
513+
edge_var.n3(),
514+
child,
515+
'\n '.join([subst.n3() for subst in substs]),
516+
sample_n,
517+
'\n'.join([
518+
' %.3f: %s' % (c, v.n3()) for v, c in prio.most_common()]),
515519
)
516520
fixed = True
517-
res = [
521+
orig_child = child
522+
children = [
518523
GraphPattern(child, mapping={edge_var: subst})
519524
for subst in substs
520525
]
521-
return res, fixed
526+
children = [
527+
c if fit_to_live(c) else orig_child
528+
for c in children
529+
]
530+
if children:
531+
child = random.choice(list(children))
532+
return child, fixed
522533

523534

524535
def mutate_add_edge(child):

gp_query.py

Lines changed: 56 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from graph_pattern import TARGET_VAR
3333
from graph_pattern import ASK_VAR
3434
from graph_pattern import COUNT_VAR
35+
from graph_pattern import NODE_VAR_SUM
36+
from graph_pattern import EDGE_VAR_COUNT
3537
from utils import chunker
3638
from utils import exception_stack_catcher
3739
from utils import get_path
@@ -282,7 +284,6 @@ def _combined_chunk_res(q_res, _vars, _ret_val_mapping):
282284
return chunk_res
283285

284286

285-
286287
def count_query(sparql, timeout, graph_pattern, source=None,
287288
**kwds):
288289
assert isinstance(graph_pattern, GraphPattern)
@@ -426,21 +427,6 @@ def variable_substitution_query(
426427
)
427428

428429

429-
def variable_substitution_deep_narrow_mut_query(
430-
sparql, timeout, graph_pattern, edge_var, node_var,
431-
source_target_pairs, limit_res, batch_size=config.BATCH_SIZE):
432-
_vars, _values, _ret_val_mapping = _get_vars_values_mapping(
433-
graph_pattern, source_target_pairs)
434-
_edge_var_node_var_and_vars = (edge_var, node_var, _vars)
435-
return _multi_query(
436-
sparql, timeout, graph_pattern, source_target_pairs, batch_size,
437-
_edge_var_node_var_and_vars, _values, _ret_val_mapping,
438-
_var_subst_res_init, _var_subst_dnp_chunk_q,
439-
_var_subst_dnp_chunk_result_ext, limit=limit_res,
440-
# non standard, passed via **kwds, see handling below
441-
)
442-
443-
444430
# noinspection PyUnusedLocal
445431
def _var_subst_res_init(_, **kwds):
446432
return Counter()
@@ -455,17 +441,6 @@ def _var_subst_chunk_q(gp, _sel_var_and_vars, values_chunk, limit):
455441
limit=limit)
456442

457443

458-
def _var_subst_dnp_chunk_q(gp, _edge_var_node_var_and_vars,
459-
values_chunk, limit):
460-
edge_var, node_var, _vars = _edge_var_node_var_and_vars
461-
return gp.to_find_edge_var_for_narrow_path_query(
462-
edge_var=edge_var,
463-
node_var=node_var,
464-
vars_=_vars,
465-
values={_vars: values_chunk},
466-
limit_res=limit)
467-
468-
469444
# noinspection PyUnusedLocal
470445
def _var_subst_chunk_result_ext(q_res, _sel_var_and_vars, _, **kwds):
471446
var, _vars = _sel_var_and_vars
@@ -482,23 +457,70 @@ def _var_subst_chunk_result_ext(q_res, _sel_var_and_vars, _, **kwds):
482457
return chunk_res
483458

484459

485-
def _var_subst_dnp_chunk_result_ext(q_res, _edge_var_node_var_and_vars, _, **kwds):
460+
def _var_subst_res_update(res, update, **_):
461+
res += update
462+
463+
464+
def variable_substitution_deep_narrow_mut_query(
465+
sparql, timeout, graph_pattern, edge_var, node_var,
466+
source_target_pairs, limit_res, batch_size=config.BATCH_SIZE):
467+
_vars, _values, _ret_val_mapping = _get_vars_values_mapping(
468+
graph_pattern, source_target_pairs)
469+
_edge_var_node_var_and_vars = (edge_var, node_var, _vars)
470+
return _multi_query(
471+
sparql, timeout, graph_pattern, source_target_pairs, batch_size,
472+
_edge_var_node_var_and_vars, _values, _ret_val_mapping,
473+
_var_subst_dnp_res_init, _var_subst_dnp_chunk_q,
474+
_var_subst_dnp_chunk_result_ext,
475+
_res_update=_var_subst_dnp_update,
476+
limit=limit_res,
477+
# non standard, passed via **kwds, see handling below
478+
)
479+
480+
481+
# noinspection PyUnusedLocal
482+
def _var_subst_dnp_res_init(_, **kwds):
483+
return Counter(), Counter()
484+
485+
486+
def _var_subst_dnp_chunk_q(gp, _edge_var_node_var_and_vars,
487+
values_chunk, limit):
486488
edge_var, node_var, _vars = _edge_var_node_var_and_vars
487-
chunk_res = Counter()
489+
return gp.to_find_edge_var_for_narrow_path_query(
490+
edge_var=edge_var,
491+
node_var=node_var,
492+
vars_=_vars,
493+
values={_vars: values_chunk},
494+
limit_res=limit)
495+
496+
497+
# noinspection PyUnusedLocal
498+
def _var_subst_dnp_chunk_result_ext(
499+
q_res, _edge_var_node_var_and_vars, _, **kwds):
500+
edge_var, node_var, _vars = _edge_var_node_var_and_vars
501+
chunk_edge_count, chunk_node_sum = Counter(), Counter()
488502
res_rows_path = ['results', 'bindings']
489503
bindings = sparql_json_result_bindings_to_rdflib(
490504
get_path(q_res, res_rows_path, default=[])
491505
)
492506

493507
for row in bindings:
494508
row_res = get_path(row, [edge_var])
495-
count_res = int(get_path(row, [COUNT_VAR], '0'))
496-
chunk_res[row_res] += count_res
497-
return chunk_res
509+
edge_count = int(get_path(row, [EDGE_VAR_COUNT], '0'))
510+
chunk_edge_count[row_res] += edge_count
511+
node_sum_count = int(get_path(row, [NODE_VAR_SUM], '0'))
512+
chunk_node_sum[row_res] += node_sum_count
513+
return chunk_edge_count, chunk_node_sum,
498514

499515

500-
def _var_subst_res_update(res, update, **_):
501-
res += update
516+
def _var_subst_dnp_update(res, up, **_):
517+
edge_count, node_sum_count = res
518+
try:
519+
chunk_edge_count, chunk_node_sum = up
520+
edge_count.update(chunk_edge_count)
521+
node_sum_count.update(chunk_node_sum)
522+
except ValueError:
523+
pass
502524

503525

504526
def generate_stps_from_gp(sparql, gp):

graph_pattern.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@
4343
ASK_VAR = Variable('ask')
4444
COUNT_VAR = Variable('count')
4545
EDGE_VAR_COUNT = Variable('edge_var_count')
46-
NODE_VAR_COUNT = Variable('node_var_count')
47-
MAX_NODE_COUNT = Variable('max_node_count')
48-
PRIO_VAR = Variable('prio_var')
46+
NODE_VAR_SUM = Variable('node_var_sum')
4947

5048

5149
def gen_random_var():
@@ -581,7 +579,7 @@ def _sparql_query_pattern_part(
581579
res = self._sparql_values_part(values, indent)
582580
tres = []
583581
for s, p, o in self:
584-
tres.append('%s %s %s .' % (s.n3(), p.n3(), o.n3()))
582+
tres.append('%s %s %s.' % (s.n3(), p.n3(), o.n3()))
585583
res += indent + ('\n' + indent).join(tres) + '\n'
586584
if bind:
587585
res += '%sFILTER(\n' % indent
@@ -667,22 +665,7 @@ def to_count_var_over_values_query(self, var, vars_, values, limit):
667665
"""Counts possible fulfilling substitutions for var.
668666
669667
Meant to perform a query like this:
670-
SELECT ?var count(*) as ?count WHERE {
671-
VALUES (?source ?target) {
672-
(dbr:Adolescence dbr:Youth)
673-
(dbr:Adult dbr:Child)
674-
(dbr:Angel dbr:Heaven)
675-
(dbr:Arithmetic dbr:Mathematics)
676-
}
677-
{
678-
SELECT DISTINCT ?source ?target ?var WHERE {
679-
?source ?edge ?target .
680-
?var dbo:wikiPageWikiLink ?target .
681-
}
682-
}
683-
}
684-
ORDER BY desc(?count)
685-
LIMIT 10
668+
686669
687670
:param var: Variable to count over.
688671
:param vars_: List of vars to fix values for (e.g. ?source, ?target).
@@ -714,10 +697,11 @@ def to_count_var_over_values_query(self, var, vars_, values, limit):
714697
res += 'LIMIT %d\n' % limit
715698
return self._sparql_prefix(res)
716699

717-
def to_find_edge_var_for_narrow_path_query\
718-
(self, edge_var, node_var, vars_, values, limit_res,
719-
filter_node_count=config.MUTPB_DN_FILTER_NODE_COUNT,
720-
filter_edge_count=config.MUTPB_DN_FILTER_EDGE_COUNT):
700+
def to_find_edge_var_for_narrow_path_query(
701+
self, edge_var, node_var, vars_, values, limit_res,
702+
filter_node_count=config.MUTPB_DN_FILTER_NODE_COUNT,
703+
filter_edge_count=config.MUTPB_DN_FILTER_EDGE_COUNT,
704+
):
721705
"""Counts possible substitutions for edge_var to get a narrow path
722706
723707
Meant to perform a query like this:
@@ -763,14 +747,15 @@ def to_find_edge_var_for_narrow_path_query\
763747

764748
res = 'SELECT * WHERE {\n'
765749
res += ' {\n'\
766-
' SELECT %s (COUNT(*) AS %s) (MAX(%s) AS %s) ' \
767-
'(COUNT(*)/AVG(%s) AS %s) WHERE {\n' % (
768-
edge_var.n3(), EDGE_VAR_COUNT.n3(),
769-
NODE_VAR_COUNT.n3(), MAX_NODE_COUNT.n3(),
770-
NODE_VAR_COUNT.n3(), PRIO_VAR.n3())
771-
res += ' SELECT DISTINCT %s %s (COUNT(%s) AS %s) WHERE {\n' % (
772-
' '.join([v.n3() for v in vars_]),
773-
edge_var.n3(), node_var.n3(), NODE_VAR_COUNT.n3())
750+
' SELECT %s (SUM (?node_var_count) AS %s) (COUNT(%s) AS %s) ' \
751+
'(MAX(?node_var_count) AS ?max_node_count) WHERE {\n' % (
752+
edge_var.n3(),
753+
NODE_VAR_SUM.n3(),
754+
' && '.join([v.n3() for v in vars_]),
755+
EDGE_VAR_COUNT.n3(), )
756+
res += ' SELECT DISTINCT %s %s (COUNT(%s) AS ?node_var_count) ' \
757+
'WHERE {\n ' % (' '.join([v.n3() for v in vars_]),
758+
edge_var.n3(), node_var.n3(), )
774759
res += self._sparql_values_part(values)
775760

776761
# triples part
@@ -783,14 +768,12 @@ def to_find_edge_var_for_narrow_path_query\
783768
res += ' }\n'\
784769
' }\n'
785770
res += ' GROUP BY %s\n' % edge_var.n3()
786-
res += ' ORDER BY DESC(%s)\n' % EDGE_VAR_COUNT.n3()
787771
res += ' }\n'
788-
res += ' FILTER(%s < %d && %s > %d)\n' % (MAX_NODE_COUNT.n3(),
789-
filter_node_count,
790-
EDGE_VAR_COUNT.n3(),
791-
filter_edge_count)
772+
res += ' FILTER(?max_node_count < %d && %s > %d)\n' \
773+
% (filter_node_count, EDGE_VAR_COUNT.n3(),
774+
filter_edge_count)
792775
res += '}\n'
793-
res += 'ORDER BY DESC(%s)\n' % PRIO_VAR.n3()
776+
res += 'ORDER BY ASC(%s)\n' % NODE_VAR_SUM.n3()
794777
res += 'LIMIT %d' % limit_res
795778
return self._sparql_prefix(res)
796779

0 commit comments

Comments
 (0)