Skip to content

Commit 907f581

Browse files
committed
eval pattern generator now allows excluding source_target_edges
1 parent 443799b commit 907f581

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

eval.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def patterns(
176176
loops=True,
177177
node_edge_joint=True,
178178
p_connected=True,
179+
source_target_edges=True,
179180
exclude_isomorphic=True,
180181
count_candidates_only=False,
181182
):
@@ -192,7 +193,16 @@ def patterns(
192193
node_edge_joint=node_edge_joint,
193194
p_connected=p_connected,
194195
)):
195-
numbers = sorted(set([v for t in num_pat for v in t]))
196+
flat_num_pat = [v for t in num_pat for v in t]
197+
all_numbers = set(flat_num_pat)
198+
199+
if source_target_edges:
200+
all_numbers = sorted(all_numbers)
201+
numbers = all_numbers
202+
else:
203+
numbers = sorted(all_numbers - set(flat_num_pat[1::3]))
204+
all_numbers = sorted(all_numbers)
205+
196206
# var_map = {i: '?v%d' % i for i in numbers}
197207
# pattern = GraphPattern(
198208
# tuple([tuple([var_map[i] for i in t]) for t in numerical_repr]))
@@ -210,7 +220,9 @@ def patterns(
210220

211221
for s, t in permutations(numbers, 2):
212222
pid += 1
213-
leftover_numbers = [n for n in numbers if n != s and n != t]
223+
# source and target are mapped to numbers s and t
224+
# re-enumerate the leftover numbers to close "holes"
225+
leftover_numbers = [n for n in all_numbers if n != s and n != t]
214226
var_map = {n: Variable('v%d' % i)
215227
for i, n in enumerate(leftover_numbers)}
216228
var_map[s] = SOURCE_VAR
@@ -245,6 +257,7 @@ def pattern_generator(
245257
loops=True,
246258
node_edge_joint=True,
247259
p_connected=True,
260+
source_target_edges=True,
248261
exclude_isomorphic=True,
249262
):
250263
canonicalized_patterns = {}
@@ -344,6 +357,7 @@ def main():
344357
loops=False,
345358
node_edge_joint=False,
346359
p_connected=False,
360+
source_target_edges=False,
347361
exclude_isomorphic=canonical,
348362
count_candidates_only=False,
349363
)):
@@ -362,8 +376,8 @@ def main():
362376
if len(mod_gp) == length:
363377
cmod_pg = canonicalize(mod_gp)
364378
assert cmod_pg in _patterns, \
365-
'mod_gp: %r\ncanon: %r\n_patterns: %r' % (
366-
mod_gp, cmod_pg, _patterns
379+
'gp: %smod_gp: %scanon: %s_patterns: %r...' % (
380+
gp, mod_gp, cmod_pg, list(_patterns)[:20]
367381
)
368382

369383

0 commit comments

Comments
 (0)