Skip to content

Commit 6407965

Browse files
committed
eval old pattern generator improved wrt. node_edge_joints
1 parent 3e1ccdc commit 6407965

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

eval.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ def patterns(
191191
"""Takes a numerical pattern and generates actual patterns from it."""
192192
assert not count_candidates_only or not exclude_isomorphic, \
193193
'count_candidates_only cannot be used with isomorphism check'
194+
assert not source_target_edges or node_edge_joint, \
195+
'source_target_edges cannot be used without node_edge_joint'
194196

195197
canonicalized_patterns = {}
196198

@@ -274,25 +276,40 @@ def pattern_generator(
274276
p_only_connected=True,
275277
source_target_edges=True,
276278
exclude_isomorphic=True,
279+
count_candidates_only=False,
277280
):
281+
assert not source_target_edges or node_edge_joint, \
282+
'source_target_edges cannot be used without node_edge_joint'
278283
canonicalized_patterns = {}
279284

280-
# To be connected there are max 3 + 2 + 2 + 2 + ... vars for the triples.
281-
# The first can be 3 different ones (including ?source and ?target, then
282-
# in each of the following triples at least one var has to be an old one
283-
possible_vars = [Variable('v%d' % i) for i in range((2 * length) - 1)]
284-
possible_vars += [SOURCE_VAR, TARGET_VAR]
285+
if node_edge_joint:
286+
# To be connected there are max 3 + 2 + 2 + 2 + ... vars for triples.
287+
# The first can be 3 different ones (including ?source and ?target), then
288+
# in each of the following triples at least one var has to be an old one
289+
possible_vars = [Variable('v%d' % i) for i in range((2 * length) - 1)]
290+
possible_nodes = possible_vars + [SOURCE_VAR, TARGET_VAR]
291+
if source_target_edges:
292+
possible_edges = possible_nodes
293+
else:
294+
possible_edges = possible_vars
295+
else:
296+
possible_var_nodes = [Variable('n%d' % i) for i in range(length - 1)]
297+
possible_nodes = possible_var_nodes + [SOURCE_VAR, TARGET_VAR]
298+
possible_edges = [Variable('e%d' % i) for i in range(length)]
285299

286300
possible_triples = [
287301
(s, p, o)
288-
for s in possible_vars
289-
for p in possible_vars
290-
for o in possible_vars
302+
for s in possible_nodes
303+
for p in possible_edges
304+
for o in possible_nodes
291305
]
292306

293307
n_patterns = binom(len(possible_triples), length)
294308
logger.info(
295309
'generating %d possible patterns of length %d', n_patterns, length)
310+
if count_candidates_only:
311+
yield (n_patterns, None)
312+
return
296313

297314
i = 0
298315
pid = 0
@@ -304,10 +321,19 @@ def pattern_generator(
304321
logger.debug(
305322
'excluded %d: source or target missing: %s', pid, gp)
306323
continue
324+
nodes = sorted(gp.nodes - {SOURCE_VAR, TARGET_VAR})
325+
edges = sorted(gp.edges - {SOURCE_VAR, TARGET_VAR})
307326
vars_ = sorted(gp.vars_in_graph - {SOURCE_VAR, TARGET_VAR})
308327

309-
# check there are no skipped nodes, e.g., link to n2 picked but no n1
310-
if vars_ != possible_vars[:len(vars_)]:
328+
# check there are no skipped variables (nodes or edges)
329+
# noinspection PyUnboundLocalVariable
330+
if (
331+
(node_edge_joint and vars_ != possible_vars[:len(vars_)]) or
332+
(not node_edge_joint and (
333+
nodes != possible_var_nodes[:len(nodes)] or
334+
edges != possible_edges[:len(edges)]
335+
))
336+
):
311337
logger.debug('excluded %d: skipped var: %s', pid, gp)
312338
continue
313339

@@ -365,11 +391,11 @@ def main():
365391
# len | typical | candidates | candidates |
366392
# | (canonical) | (old method) | (numerical) |
367393
# ----+-------------+----------------+-------------+
368-
# 1 | 2 | 27 | 2 |
369-
# 2 | 28 | 7750 | 54 |
370-
# 3 | 486 | 6666891 | 1614 |
371-
# 4 | 10374 | 11671285626 | 59654 |
372-
# 5 | | 34549552710596 | 2707960 |
394+
# 1 | 2 | 4 | 2 |
395+
# 2 | 28 | 153 | 54 |
396+
# 3 | 486 | 17296 | 1614 |
397+
# 4 | 10374 | 3921225 | 59654 |
398+
# 5 | | 1488847536 | 2707960 |
373399

374400
# typical above means none of (loops, nej, pcon, source_target_edges)
375401

0 commit comments

Comments
 (0)