33
33
from rdflib import Variable
34
34
import SPARQLWrapper
35
35
from scoop .futures import map as parallel_map
36
+ import six
36
37
37
38
import logging
38
39
logger = logging .getLogger (__name__ )
@@ -1177,7 +1178,31 @@ def find_graph_pattern_coverage(
1177
1178
return patterns , coverage_counts , gtp_scores
1178
1179
1179
1180
1180
- def predict_target (sparql , timeout , gps , source , parallel = None ):
1181
def predict_target_candidates(sparql, timeout, gps, source, parallel=None):
    """Uses the gps to predict target candidates for the given source.

    :param sparql: SPARQLWrapper endpoint.
    :param timeout: Timeout in seconds for each individual query (gp).
    :param gps: A list of evaluated GraphPattern objects (fitness is used).
    :param source: source node
    :param parallel: execute prediction queries in parallel? Defaults to
        config.PREDICTION_IN_PARALLEL when None.
    :return: A list of pairs [(target_candidates, gp)].
    """
    if parallel is None:
        parallel = config.PREDICTION_IN_PARALLEL

    pq = partial(
        predict_query,
        sparql, timeout,
        source=source,
    )
    map_ = parallel_map if parallel else map
    results = map_(pq, gps)
    # predict_query returns pairs; we keep only the second element (the
    # target candidates) — first element presumably timing/meta info,
    # TODO(review): confirm against predict_query.
    # Materialize as a list: under Python 3 (this file imports six for
    # py2/3 compat) zip() is a lazy one-shot iterator, which would break
    # callers that iterate the result more than once and contradicts the
    # documented return type.
    return list(zip([res for _, res in results], gps))
1203
+
1204
+
1205
+ def fuse_prediction_results (predict_query_results , fusion_methods = None ):
1181
1206
"""Several naive prediction methods for targets given a source.
1182
1207
1183
1208
The naive methods here are used to re-assemble all result lists returned by
@@ -1194,17 +1219,14 @@ def predict_target(sparql, timeout, gps, source, parallel=None):
1194
1219
- 'f_measures_precisions': same as above but scaled with precision
1195
1220
- 'gp_precisions_precisions': same as above but scaled with precision
1196
1221
1197
- :param sparql: SPARQLWrapper endpoint.
1198
- :param timeout: Timeout in seconds for each individual query (gp).
1199
- :param gps: A list of evaluated GraphPattern objects (fitness is used).
1200
- :param source: source node
1201
- :param parallel: execute prediction queries in parallel?
1222
+ :param predict_query_results: a list of [(target_candidates, gp)] as
1223
+ returned by predict_target_candidates().
1224
+ :param fusion_methods: None for all or a list of strings naming the fusion
1225
+ methods to return.
1202
1226
:return: A dict like {method: ranked_res_list}, where ranked_res_list is a
1203
1227
list result list produced by method of (predicted_target, score) pairs
1204
1228
ordered decreasingly by score. For methods see above.
1205
1229
"""
1206
- if parallel is None :
1207
- parallel = config .PREDICTION_IN_PARALLEL
1208
1230
target_occs = Counter ()
1209
1231
scores = Counter ()
1210
1232
f_measures = Counter ()
@@ -1216,15 +1238,8 @@ def predict_target(sparql, timeout, gps, source, parallel=None):
1216
1238
gp_precisions_precisions = Counter ()
1217
1239
1218
1240
# TODO: add cut-off values for methods (will have different recalls then)
1219
- pq = partial (
1220
- predict_query ,
1221
- sparql , timeout ,
1222
- source = source ,
1223
- )
1224
- map_ = parallel_map if parallel else map
1225
- results = map_ (pq , gps )
1226
- results = zip ([res for _ , res in results ], gps )
1227
- for res , gp in results :
1241
+
1242
+ for res , gp in predict_query_results :
1228
1243
score = gp .fitness .values .score
1229
1244
fm = gp .fitness .values .f_measure
1230
1245
gp_precision = 1
@@ -1256,9 +1271,25 @@ def predict_target(sparql, timeout, gps, source, parallel=None):
1256
1271
('f_measures_precisions' , f_measures_precisions .most_common ()),
1257
1272
('gp_precisions_precisions' , gp_precisions_precisions .most_common ()),
1258
1273
])
1274
+ if fusion_methods :
1275
+ # TODO: could improve by not actually calculating them
1276
+ for k in res .keys ():
1277
+ if k not in fusion_methods :
1278
+ del res [k ]
1259
1279
return res
1260
1280
1261
1281
1282
def predict_target(
        sparql, timeout, gps, source,
        parallel=None, fusion_methods=None
):
    """Predict targets for source and fuse the per-pattern candidates.

    Convenience wrapper: runs predict_target_candidates() and feeds its
    result into fuse_prediction_results().

    :param sparql: SPARQLWrapper endpoint.
    :param timeout: Timeout in seconds for each individual query (gp).
    :param gps: A list of evaluated GraphPattern objects (fitness is used).
    :param source: source node
    :param parallel: execute prediction queries in parallel?
    :param fusion_methods: None for all or a list of strings naming the
        fusion methods to return.
    :return: A dict like {method: ranked_res_list}, see
        fuse_prediction_results().
    """
    candidates = predict_target_candidates(sparql, timeout, gps, source, parallel)
    return fuse_prediction_results(candidates, fusion_methods)
1291
+
1292
+
1262
1293
def find_in_prediction (prediction , target ):
1263
1294
try :
1264
1295
targets , scores = zip (* prediction )
@@ -1267,12 +1298,25 @@ def find_in_prediction(prediction, target):
1267
1298
return - 1
1268
1299
1269
1300
1301
def print_prediction_results(method, res, target=None, idx=None):
    """Print the top 10 predictions produced by a single fusion method.

    :param method: Name of the fusion method that produced res.
    :param res: Result list of (predicted_target, score) pairs, ordered
        decreasingly by score. Each predicted_target must support .n3().
    :param target: Ground truth target node (evaluation mode) or None.
    :param idx: Index of target in res, or None. Must be given iff
        target is given.
    :raises ValueError: if exactly one of target and idx is None.
    """
    # target and idx belong together: both set (evaluation mode) or both
    # omitted (plain prediction mode). Raise instead of assert so the
    # check survives running under `python -O`.
    if (target is None) ^ (idx is None):
        raise ValueError("target and idx should both be None or neither")
    print(
        '  Top 10 predictions (method: %s)%s' % (
            method, (", target at idx: %d" % idx) if idx is not None else ''))
    for t, score in res[:10]:
        # mark the ground truth target with an arrow if it made the top 10
        print(
            '    ' + ('->' if t == target else '  ') +
            '%s (%.3f)' % (t.n3(), score)
        )
1312
+
1313
+
1270
1314
def evaluate_prediction (sparql , gps , predict_list ):
1271
1315
recall = 0
1272
1316
method_idxs = defaultdict (list )
1273
1317
res_lens = []
1274
1318
timeout = calibrate_query_timeout (sparql )
1275
- for i , (source , target ) in enumerate (predict_list ):
1319
+ for i , (source , target ) in enumerate (predict_list , 1 ):
1276
1320
print ('%d/%d: predicting target for %s (ground truth: %s):' % (
1277
1321
i , len (predict_list ), source .n3 (), target .n3 ()))
1278
1322
method_res = predict_target (sparql , timeout , gps , source )
@@ -1290,14 +1334,7 @@ def evaluate_prediction(sparql, gps, predict_list):
1290
1334
print (' result list length: %d' % n )
1291
1335
method_idxs [method ].append (idx )
1292
1336
1293
- print (
1294
- ' Top 10 predictions (method: %s), target at idx: %d' % (
1295
- method , idx ))
1296
- for t , score in res [:10 ]:
1297
- print (
1298
- ' ' + ('->' if t == target else ' ' ) +
1299
- '%s (%.3f)' % (t .n3 (), score )
1300
- )
1337
+ print_prediction_results (method , res , target , idx )
1301
1338
1302
1339
recall /= len (predict_list )
1303
1340
print ("Prediction list: %s" % predict_list )
@@ -1462,7 +1499,7 @@ def main(
1462
1499
sys .stderr .flush ()
1463
1500
1464
1501
1465
- if predict :
1502
+ if predict and predict != 'manual' :
1466
1503
assert predict in ('train_set' , 'test_set' )
1467
1504
predict_list = assocs_train if predict == 'train_set' else assocs_test
1468
1505
print ('\n \n \n starting prediction on %s' % predict )
@@ -1475,4 +1512,18 @@ def main(
1475
1512
main_end = datetime .utcnow ()
1476
1513
logging .info ('Overall execution took: %s' , main_end - main_start )
1477
1514
1478
- # TODO: make continuous prediction mode
1515
+ if predict == 'manual' :
1516
+ timeout = calibrate_query_timeout (sparql )
1517
+ sys .stdout .flush ()
1518
+ sys .stderr .flush ()
1519
+
1520
+ while True :
1521
+ s = six .moves .input (
1522
+ '\n \n Enter a DBpedia resource as source:\n '
1523
+ '> http://dbpedia.org/resource/'
1524
+ )
1525
+ source = URIRef ('http://dbpedia.org/resource/' + s )
1526
+
1527
+ method_res = predict_target (sparql , timeout , gps , source )
1528
+ for method , res in method_res .items ():
1529
+ print_prediction_results (method , res )
0 commit comments