
Commit 87ddb07

added a manual prediction mode (and some refactoring for that) for live debugging
1 parent 19c49b1 commit 87ddb07

File tree

2 files changed (+80, -29 lines)


gp_learner.py

Lines changed: 79 additions & 28 deletions
@@ -33,6 +33,7 @@
 from rdflib import Variable
 import SPARQLWrapper
 from scoop.futures import map as parallel_map
+import six

 import logging
 logger = logging.getLogger(__name__)
@@ -1177,7 +1178,31 @@ def find_graph_pattern_coverage(
     return patterns, coverage_counts, gtp_scores


-def predict_target(sparql, timeout, gps, source, parallel=None):
+def predict_target_candidates(sparql, timeout, gps, source, parallel=None):
+    """Uses the gps to predict target candidates for the given source.
+
+    :param sparql: SPARQLWrapper endpoint.
+    :param timeout: Timeout in seconds for each individual query (gp).
+    :param gps: A list of evaluated GraphPattern objects (fitness is used).
+    :param source: source node
+    :param parallel: execute prediction queries in parallel?
+    :return: A list of pairs [(target_candidates, gp)]
+    """
+    if parallel is None:
+        parallel = config.PREDICTION_IN_PARALLEL
+
+    pq = partial(
+        predict_query,
+        sparql, timeout,
+        source=source,
+    )
+    map_ = parallel_map if parallel else map
+    results = map_(pq, gps)
+    results = zip([res for _, res in results], gps)
+    return results
+
+
+def fuse_prediction_results(predict_query_results, fusion_methods=None):
     """Several naive prediction methods for targets given a source.

     The naive methods here are used to re-assemble all result lists returned by
@@ -1194,17 +1219,14 @@ def predict_target(sparql, timeout, gps, source, parallel=None):
     - 'f_measures_precisions': same as above but scaled with precision
     - 'gp_precisions_precisions': same as above but scaled with precision

-    :param sparql: SPARQLWrapper endpoint.
-    :param timeout: Timeout in seconds for each individual query (gp).
-    :param gps: A list of evaluated GraphPattern objects (fitness is used).
-    :param source: source node
-    :param parallel: execute prediction queries in parallel?
+    :param predict_query_results: a list of [(target_candidates, gp)] as
+        returned by predict_target_candidates().
+    :param fusion_methods: None for all or a list of strings naming the fusion
+        methods to return.
     :return: A dict like {method: ranked_res_list}, where ranked_res_list is a
         list result list produced by method of (predicted_target, score) pairs
         ordered decreasingly by score. For methods see above.
     """
-    if parallel is None:
-        parallel = config.PREDICTION_IN_PARALLEL
     target_occs = Counter()
     scores = Counter()
     f_measures = Counter()
@@ -1216,15 +1238,8 @@ def predict_target(sparql, timeout, gps, source, parallel=None):
     gp_precisions_precisions = Counter()

     # TODO: add cut-off values for methods (will have different recalls then)
-    pq = partial(
-        predict_query,
-        sparql, timeout,
-        source=source,
-    )
-    map_ = parallel_map if parallel else map
-    results = map_(pq, gps)
-    results = zip([res for _, res in results], gps)
-    for res, gp in results:
+
+    for res, gp in predict_query_results:
         score = gp.fitness.values.score
         fm = gp.fitness.values.f_measure
         gp_precision = 1
@@ -1256,9 +1271,25 @@ def predict_target(sparql, timeout, gps, source, parallel=None):
         ('f_measures_precisions', f_measures_precisions.most_common()),
         ('gp_precisions_precisions', gp_precisions_precisions.most_common()),
     ])
+    if fusion_methods:
+        # TODO: could improve by not actually calculating them
+        for k in res.keys():
+            if k not in fusion_methods:
+                del res[k]
     return res


+def predict_target(
+        sparql, timeout, gps, source,
+        parallel=None, fusion_methods=None
+):
+    """Predict candidates and fuse the results."""
+    return fuse_prediction_results(
+        predict_target_candidates(sparql, timeout, gps, source, parallel),
+        fusion_methods
+    )
+
+
 def find_in_prediction(prediction, target):
     try:
         targets, scores = zip(*prediction)
@@ -1267,12 +1298,25 @@ def find_in_prediction(prediction, target):
         return -1


+def print_prediction_results(method, res, target=None, idx=None):
+    assert not ((target is None) ^ (idx is None)), \
+        "target and idx should both be None or neither"
+    print(
+        ' Top 10 predictions (method: %s)%s' % (
+            method, (", target at idx: %d" % idx) if idx is not None else ''))
+    for t, score in res[:10]:
+        print(
+            ' ' + ('->' if t == target else ' ') +
+            '%s (%.3f)' % (t.n3(), score)
+        )
+
+
 def evaluate_prediction(sparql, gps, predict_list):
     recall = 0
     method_idxs = defaultdict(list)
     res_lens = []
     timeout = calibrate_query_timeout(sparql)
-    for i, (source, target) in enumerate(predict_list):
+    for i, (source, target) in enumerate(predict_list, 1):
         print('%d/%d: predicting target for %s (ground truth: %s):' % (
             i, len(predict_list), source.n3(), target.n3()))
         method_res = predict_target(sparql, timeout, gps, source)
@@ -1290,14 +1334,7 @@ def evaluate_prediction(sparql, gps, predict_list):
            print(' result list length: %d' % n)
            method_idxs[method].append(idx)

-           print(
-               ' Top 10 predictions (method: %s), target at idx: %d' % (
-                   method, idx))
-           for t, score in res[:10]:
-               print(
-                   ' ' + ('->' if t == target else ' ') +
-                   '%s (%.3f)' % (t.n3(), score)
-               )
+           print_prediction_results(method, res, target, idx)

     recall /= len(predict_list)
     print("Prediction list: %s" % predict_list)
@@ -1462,7 +1499,7 @@ def main(
     sys.stderr.flush()


-    if predict:
+    if predict and predict != 'manual':
         assert predict in ('train_set', 'test_set')
         predict_list = assocs_train if predict == 'train_set' else assocs_test
         print('\n\n\nstarting prediction on %s' % predict)
@@ -1475,4 +1512,18 @@ def main(
     main_end = datetime.utcnow()
     logging.info('Overall execution took: %s', main_end - main_start)

-    # TODO: make continuous prediction mode
+    if predict == 'manual':
+        timeout = calibrate_query_timeout(sparql)
+        sys.stdout.flush()
+        sys.stderr.flush()
+
+        while True:
+            s = six.moves.input(
+                '\n\nEnter a DBpedia resource as source:\n'
+                '> http://dbpedia.org/resource/'
+            )
+            source = URIRef('http://dbpedia.org/resource/' + s)
+
+            method_res = predict_target(sparql, timeout, gps, source)
+            for method, res in method_res.items():
+                print_prediction_results(method, res)
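
The refactoring above splits the old monolithic predict_target() into predict_target_candidates() and fuse_prediction_results(), which is what makes live debugging practical: the raw per-pattern candidate sets can be inspected before fusion. Below is a minimal sketch of how that split could be used interactively; the endpoint URL, the import paths, the example source resource and the load_gps() helper are assumptions for illustration, not part of this commit.

# Sketch only: endpoint, import paths and load_gps() are assumed/hypothetical.
import SPARQLWrapper
from rdflib import URIRef

from gp_learner import (
    calibrate_query_timeout, fuse_prediction_results, predict_target_candidates
)

sparql = SPARQLWrapper.SPARQLWrapper('http://dbpedia.org/sparql')  # assumed endpoint
timeout = calibrate_query_timeout(sparql)
gps = load_gps()  # hypothetical helper returning the learned GraphPattern list
source = URIRef('http://dbpedia.org/resource/Berlin')  # example source

# Stage 1: one candidate set per graph pattern, handy to eyeball while debugging.
candidates = list(predict_target_candidates(sparql, timeout, gps, source))
for target_candidates, gp in candidates:
    print(len(target_candidates), gp.fitness.values.f_measure)

# Stage 2: fuse into ranked predictions, optionally restricted to some methods.
fused = fuse_prediction_results(
    candidates, fusion_methods=['f_measures_precisions'])
for method, ranked in fused.items():
    print(method, ranked[:3])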

run.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@
          "disable evaluation set to ''.",
     action="store",
     type=str,
-    choices=("test_set", "train_set", ""),
+    choices=("test_set", "train_set", "manual", ""),
     default="",
 )

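
For context, the run.py hunk above changes one choices tuple inside an argparse option. The sketch below shows how the full registration presumably looks with the new value; only the lines shown in the diff are confirmed, while the option name "--predict" and the wording of the help text are assumptions (the name matches the predict parameter used in gp_learner.main()).

# Sketch only: option name and help text are assumptions.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--predict",  # assumed option name
    help="what to predict after learning: 'train_set', 'test_set', the new "
         "interactive 'manual' mode; to disable evaluation set to ''.",
    action="store",
    type=str,
    choices=("test_set", "train_set", "manual", ""),
    default="",
)
args = parser.parse_args()
# e.g.:  python run.py --predict manual
# which enters the loop added above, prompting for DBpedia resources as sources.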