Skip to content

Commit 443799b

Browse files
committed
numerical pattern generator can now also exclude loops, node_edge_joint and p_connected patterns
1 parent 28c6c18 commit 443799b

File tree

1 file changed

+81
-31
lines changed

1 file changed

+81
-31
lines changed

eval.py

Lines changed: 81 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
logger.info('init')
2727

2828

29-
DEBUG = False
29+
DEBUG = True
3030
HOLE = sys.maxint # placeholder for holes in partial patterns
3131

3232
# debug logging in this module is actually quite expensive (> 30 % of time). In
@@ -40,6 +40,9 @@ def quick_skip_debug_log(*args, **kwds):
4040

4141
def numerical_patterns(
4242
length,
43+
loops=True,
44+
node_edge_joint=True,
45+
p_connected=True,
4346
_partial_pattern=None,
4447
_pos=None,
4548
_var=1,
@@ -91,28 +94,49 @@ def numerical_patterns(
9194
# exclude multiple equivalent triples
9295
return
9396

94-
if i >= 1 and j == 2:
95-
# we just completed a triple, check that it's connected
96-
t = _partial_pattern[i]
97-
for pt in _partial_pattern[:i]:
98-
if t[0] in pt or t[1] in pt or t[2] in pt:
99-
break
100-
else:
101-
# we're not connected, early terminate this
102-
# This is safe as a later triple can't reconnect us anymore without
103-
# an isomorphic, lower enumeration that would've been encountered
104-
# before:
105-
# say we have
106-
# abc xyz uvw
107-
# with xyz not being connected yet and uvw or any later part
108-
# connecting xyz back to abc. We can just use a breadth first search
109-
# from abc via those connecting triples and re-label all encountered
110-
# vars by breadth first search encountering. That re-labeling is
111-
# guaranteed to forward connect and it will generate a smaller
112-
# labelling than the current one.
97+
98+
# check if nodes and edges are disjoint
99+
if not node_edge_joint:
100+
flat_pp = [v for t in _partial_pattern for v in t]
101+
end = i*3 + j + 1 # end including last var
102+
nodes = set(flat_pp[0:end:3] + flat_pp[2:end:3])
103+
edges = set(flat_pp[1:end:3])
104+
if nodes & edges:
105+
logger.debug(
106+
'excluded node-edge-joined: %s', _partial_pattern[:i+1])
113107
return
114108

115-
if i >= length - 1 and j >= 2:
109+
if j == 2: # we just completed a triple
110+
# check for loops if necessary
111+
if not loops:
112+
s, _, o = _partial_pattern[i]
113+
if s == o:
114+
logger.debug('excluded loop: %s', _partial_pattern[:i+1])
115+
return
116+
117+
if i >= 1: # we're in a follow-up triple (excluding first)
118+
# check that it's connected
119+
s, p, o = _partial_pattern[i]
120+
for pt in _partial_pattern[:i]:
121+
# loop over previous triples and check if current is connected
122+
if s in pt or o in pt or (p_connected and p in pt):
123+
break
124+
else:
125+
# we're not connected, early terminate this
126+
# This is safe as a later triple can't reconnect us anymore
127+
# without an isomorphic, lower enumeration that would've been
128+
# encountered before:
129+
# say we have
130+
# abc xyz uvw
131+
# with xyz not being connected yet and uvw or any later part
132+
# connecting xyz back to abc. We can just use a breadth first
133+
# search from abc via those connecting triples and re-label all
134+
# encountered vars in the order the breadth first search visits them. That
135+
# re-labeling is guaranteed to forward connect and it will
136+
# generate a smaller labelling than the current one.
137+
return
138+
139+
if i == length - 1 and j == 2:
116140
# we're at the end of the pattern
117141
yield _partial_pattern
118142
else:
@@ -137,6 +161,9 @@ def numerical_patterns(
137161
for v in range(_star_var, _end_var + 1):
138162
for pattern in numerical_patterns(
139163
length,
164+
loops=loops,
165+
node_edge_joint=node_edge_joint,
166+
p_connected=p_connected,
140167
_partial_pattern=_partial_pattern,
141168
_pos=(i, j),
142169
_var=v
@@ -146,6 +173,9 @@ def numerical_patterns(
146173

147174
def patterns(
148175
length,
176+
loops=True,
177+
node_edge_joint=True,
178+
p_connected=True,
149179
exclude_isomorphic=True,
150180
count_candidates_only=False,
151181
):
@@ -156,7 +186,12 @@ def patterns(
156186
canonicalized_patterns = {}
157187

158188
pid = -1
159-
for c, num_pat in enumerate(numerical_patterns(length)):
189+
for c, num_pat in enumerate(numerical_patterns(
190+
length,
191+
loops=loops,
192+
node_edge_joint=node_edge_joint,
193+
p_connected=p_connected,
194+
)):
160195
numbers = sorted(set([v for t in num_pat for v in t]))
161196
# var_map = {i: '?v%d' % i for i in numbers}
162197
# pattern = GraphPattern(
@@ -291,6 +326,7 @@ def pattern_generator(
291326

292327
def main():
293328
length = 1
329+
canonical = True
294330
# len | pcon | nej | pcon, nej | candidates | candidates |
295331
# | | | (canonical) | (old method) | (numerical) |
296332
# ----+------+-----+--------------+----------------+-------------+
@@ -301,20 +337,34 @@ def main():
301337
# 5 | | | | 34549552710596 | 3461471628 |
302338

303339
gen_patterns = []
340+
n = 0
304341
i = 0
305-
for n, (i, pattern) in enumerate(patterns(length, False, True)):
342+
for n, (i, pattern) in enumerate(patterns(
343+
length,
344+
loops=False,
345+
node_edge_joint=False,
346+
p_connected=False,
347+
exclude_isomorphic=canonical,
348+
count_candidates_only=False,
349+
)):
306350
print('%d: Pattern id %d: %s' % (n, i, pattern))
307351
gen_patterns.append((i, pattern))
308-
print(i)
352+
print('Number of pattern candidates: %d' % i)
353+
print('Number of patterns: %d' % n)
309354
_patterns = set(gp for pid, gp in gen_patterns[:-1])
310355

311-
# testing flipped edges
312-
for gp in _patterns:
313-
for i in range(length):
314-
mod_gp = gp.flip_edge(i)
315-
# can happen that flipped edge was there already
316-
if len(mod_gp) == length:
317-
assert canonicalize(mod_gp) in _patterns
356+
# testing flipped edges (only works if we're working with canonicals)
357+
if canonical:
358+
for gp in _patterns:
359+
for i in range(length):
360+
mod_gp = gp.flip_edge(i)
361+
# can happen that flipped edge was there already
362+
if len(mod_gp) == length:
363+
cmod_pg = canonicalize(mod_gp)
364+
assert cmod_pg in _patterns, \
365+
'mod_gp: %r\ncanon: %r\n_patterns: %r' % (
366+
mod_gp, cmod_pg, _patterns
367+
)
318368

319369

320370
if __name__ == '__main__':

0 commit comments

Comments
 (0)