Skip to content

Commit 8065067

Browse files
committed
remaining_gain_sample_gtps n renamed to more appropriate max_n
1 parent 46042f4 commit 8065067

File tree

4 files changed

+24
-23
lines changed

4 files changed

+24
-23
lines changed

gp_learner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,9 @@ def mutate_fix_var(
522522
timeout,
523523
gtp_scores,
524524
child,
525-
gtp_sample_n=config.MUTPB_FV_RGTP_SAMPLE_N,
525+
gtp_sample_max_n=config.MUTPB_FV_RGTP_SAMPLE_N,
526526
rand_var=None,
527-
sample_n=config.MUTPB_FV_SAMPLE_MAXN,
527+
sample_max_n=config.MUTPB_FV_SAMPLE_MAXN,
528528
limit=config.MUTPB_FV_QUERY_LIMIT,
529529
):
530530
"""Chooses a random variable from the pattern(node or edge).
@@ -537,10 +537,11 @@ def mutate_fix_var(
537537
# The further we get, the less gtps are remaining. Sampling too many (all)
538538
# of them might hurt as common substitutions (> limit ones) which are dead
539539
# ends could cover less common ones that could actually help
540-
gtp_sample_n = min(gtp_sample_n, int(gtp_scores.remaining_gain))
541-
gtp_sample_n = random.randint(1, gtp_sample_n)
540+
gtp_sample_max_n = min(gtp_sample_max_n, int(gtp_scores.remaining_gain))
541+
gtp_sample_max_n = random.randint(1, gtp_sample_max_n)
542542

543-
ground_truth_pairs = gtp_scores.remaining_gain_sample_gtps(n=gtp_sample_n)
543+
ground_truth_pairs = gtp_scores.remaining_gain_sample_gtps(
544+
max_n=gtp_sample_max_n)
544545
rand_vars = child.vars_in_graph - {SOURCE_VAR, TARGET_VAR}
545546
if len(rand_vars) < 1:
546547
return [child]
@@ -561,13 +562,13 @@ def mutate_fix_var(
561562
return [child]
562563
# randomly pick n of the substitutions with a prob ~ to their counts
563564
items, counts = zip(*substitution_counts.most_common())
564-
substs = sample_from_list(items, counts, sample_n)
565+
substs = sample_from_list(items, counts, sample_max_n)
565566
logger.info(
566567
'fixed variable %s in %sto:\n %s\n<%d out of:\n%s\n',
567568
rand_var.n3(),
568569
child,
569570
'\n '.join([subst.n3() for subst in substs]),
570-
sample_n,
571+
sample_max_n,
571572
'\n'.join([' %d: %s' % (c, v.n3())
572573
for v, c in substitution_counts.most_common()]),
573574
)

gtp_scores.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,19 +52,19 @@ def update_with_gps(self, gps):
5252
self.gtp_max_precisions[gtp] = precision
5353
return precision_gain
5454

55-
def remaining_gain_sample_gtps(self, n=None):
55+
def remaining_gain_sample_gtps(self, max_n=None):
5656
"""Sample ground truth pairs according to remaining gains.
5757
58-
This method draws up to n ground truth pairs using their remaining gains
59-
as sample probabilities. If less than n probabilities are > 0 it draws
60-
less gtps.
58+
This method draws up to max_n ground truth pairs using their remaining
59+
gains as sample probabilities. GTPs with remaining gain of 0 are never
60+
returned, so if less than n probabilities are > 0 it draws less gtps.
6161
62-
:param n: Up to n items to sample.
62+
:param max_n: Up to n items to sample.
6363
:return: list of ground truth pairs sampled according to their remaining
6464
gains in gtp_scores with max length of n.
6565
"""
6666
gtps, gains = zip(*self.get_remaining_gains().items())
67-
return sample_from_list(gtps, gains, n)
67+
return sample_from_list(gtps, gains, max_n)
6868

6969
def __sub__(self, other):
7070
if not isinstance(other, GTPScores):

tests/test_gp_learner_offline.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,13 @@ def test_simplify_pattern():
169169
assert res == gp, 'not simplified:\n%s' % res.to_sparql_select_query()
170170

171171
# counter example of an advanced but restricting pattern:
172-
gp = gp + [
172+
gp += [
173173
(SOURCE_VAR, Variable('v3'), Variable('v4')),
174174
(Variable('v5'), Variable('v6'), Variable('v4')),
175175
(Variable('v4'), Variable('v7'), Variable('v8')),
176176
(TARGET_VAR, Variable('v3'), SOURCE_VAR),
177177
(dbp['City'], Variable('v6'), dbp['Country']),
178178
(dbp['Country'], Variable('v8'), dbp['City']),
179-
180179
]
181180
res = mutate_simplify_pattern(gp)
182181
assert res == gp, 'was simplified (bad):\n%s' % res.to_sparql_select_query()
@@ -221,22 +220,22 @@ def test_simplify_pattern():
221220

222221
def test_remaining_gain_sample_gtps():
223222
n = len(ground_truth_pairs)
224-
gtps = sorted(gtp_scores.remaining_gain_sample_gtps(n=n))
223+
gtps = sorted(gtp_scores.remaining_gain_sample_gtps(max_n=n))
225224
assert len(gtps) == n
226225
# if we draw everything the results should always be everything
227-
assert gtps == sorted(gtp_scores.remaining_gain_sample_gtps(n=n))
226+
assert gtps == sorted(gtp_scores.remaining_gain_sample_gtps(max_n=n))
228227
# if we don't draw everything it's quite unlikely we get the same result
229-
gtps = gtp_scores.remaining_gain_sample_gtps(n=5)
228+
gtps = gtp_scores.remaining_gain_sample_gtps(max_n=5)
230229
assert len(gtps) == 5
231-
assert gtps != gtp_scores.remaining_gain_sample_gtps(n=5)
230+
assert gtps != gtp_scores.remaining_gain_sample_gtps(max_n=5)
232231

233232
# make sure we never get items that are fully covered already
234233
gtp_scores.gtp_max_precisions[ground_truth_pairs[0]] = 1
235234
c = Counter()
236235
k = 100
237236
n = 128
238237
for i in range(k):
239-
c.update(gtp_scores.remaining_gain_sample_gtps(n=n))
238+
c.update(gtp_scores.remaining_gain_sample_gtps(max_n=n))
240239
assert ground_truth_pairs[0] not in c
241240
assert sum(c.values()) == k * n
242241
# count how many aren't in gtps
@@ -260,7 +259,7 @@ def test_remaining_gain_sample_gtps():
260259
assert gtpe_scores.remaining_gain == 1
261260
c = Counter()
262261
for i in range(100):
263-
c.update(gtpe_scores.remaining_gain_sample_gtps(n=1))
262+
c.update(gtpe_scores.remaining_gain_sample_gtps(max_n=1))
264263
assert len(c) == 2
265264
assert sum(c.values()) == 100
266265
assert (binom.pmf(c[high_prob], 100, .9) > 0.001 and

utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,9 @@ def sample_from_list(l, probs, max_n=None):
247247
"""Sample list according to probs.
248248
249249
This method draws up to max_n items from l using the given list of probs as
250-
sample probabilities. max_n defaults to len(l) if not specified. If less
251-
than max_n probabilities are > 0 only those items are returned.
250+
sample probabilities. max_n defaults to len(l) if not specified. Items with
251+
probability 0 are never sampled, so if less than max_n probabilities are > 0
252+
only those items are returned.
252253
253254
:param l: list from which to draw items.
254255
:param probs: List of probabilities to draw items. Normalized by sum(probs).

0 commit comments

Comments
 (0)