@@ -911,15 +911,19 @@ def generate_init_population(
911
911
return fit_population
912
912
913
913
914
- def generation_step_callback (toolbox , ngen , population ):
914
+ def generation_step_callback (
915
+ run , gtp_scores , user_callback_per_generation , ngen , population
916
+ ):
915
917
"""Called after each generation step cycle in train().
916
918
917
- :param toolbox: toolbox of the evolutionary algorithm.
919
+ :param run: number of the current run
920
+ :param gtp_scores: gtp_scores as of start of this run
921
+ :param user_callback_per_generation: a user provided callback that is called
922
+ after each training generation. It not None called like this:
923
+ user_callback_per_generation(run, gtp_scores, ngen, population)
918
924
:param ngen: the number of the current generation.
919
925
:param population: the current population after generation ngen.
920
926
"""
921
- run = toolbox .get_run ()
922
- gtp_scores = toolbox .get_gtp_scores ()
923
927
top_counter = print_population (run , ngen , population )
924
928
top_gps = sorted (
925
929
top_counter .keys (), key = attrgetter ("fitness" ), reverse = True
@@ -929,15 +933,18 @@ def generation_step_callback(toolbox, ngen, population):
929
933
save_population (
930
934
run , ngen , top_gps , generation_gtp_scores
931
935
)
936
+ if user_callback_per_generation :
937
+ # user provided callback
938
+ user_callback_per_generation (run , gtp_scores , ngen , population )
932
939
933
940
934
941
def find_graph_patterns (
935
- sparql , run , gtp_scores ):
942
+ sparql , run , gtp_scores ,
943
+ user_callback_per_generation = None ,
944
+ ):
936
945
timeout = calibrate_query_timeout (sparql )
937
946
938
947
toolbox = deap .base .Toolbox ()
939
- toolbox .register ("get_run" , lambda : run )
940
- toolbox .register ("get_gtp_scores" , lambda : gtp_scores )
941
948
942
949
toolbox .register (
943
950
"mate" , mate
@@ -952,7 +959,10 @@ def find_graph_patterns(
952
959
)
953
960
toolbox .register (
954
961
"evaluate" , evaluate , sparql , timeout , gtp_scores )
955
- toolbox .register ("generation_step_callback" , generation_step_callback )
962
+ toolbox .register (
963
+ "generation_step_callback" ,
964
+ generation_step_callback , run , gtp_scores , user_callback_per_generation
965
+ )
956
966
957
967
958
968
population = generate_init_population (
@@ -985,11 +995,15 @@ def _find_graph_pattern_coverage_run(
985
995
coverage_counts ,
986
996
gtp_scores ,
987
997
patterns ,
998
+ user_callback_per_generation = None ,
999
+ user_callback_per_run = None ,
988
1000
):
989
1001
min_fitness = calc_min_fitness (gtp_scores , min_score )
990
1002
991
1003
ngen , res_pop , hall_of_fame , toolbox = find_graph_patterns (
992
- sparql , run , gtp_scores )
1004
+ sparql , run , gtp_scores ,
1005
+ user_callback_per_generation = user_callback_per_generation ,
1006
+ )
993
1007
994
1008
# TODO: coverage patterns should be chosen based on similarity
995
1009
new_best_patterns = []
@@ -1085,6 +1099,11 @@ def _find_graph_pattern_coverage_run(
1085
1099
)
1086
1100
set_symlink (fp , config .SYMLINK_CURRENT_RES_RUN )
1087
1101
1102
+ if user_callback_per_run :
1103
+ user_callback_per_run (
1104
+ run , gtp_scores , new_best_patterns , coverage_counts
1105
+ )
1106
+
1088
1107
return new_best_patterns , coverage_counts , gtp_scores
1089
1108
1090
1109
@@ -1096,6 +1115,8 @@ def find_graph_pattern_coverage(
1096
1115
max_runs = config .NRUNS ,
1097
1116
runs_no_improvement = config .NRUNS_NO_IMPROVEMENT ,
1098
1117
error_retries = config .ERROR_RETRIES ,
1118
+ user_callback_per_generation = None ,
1119
+ user_callback_per_run = None ,
1099
1120
):
1100
1121
assert isinstance (ground_truth_pairs , tuple )
1101
1122
@@ -1135,6 +1156,8 @@ def find_graph_pattern_coverage(
1135
1156
coverage_counts ,
1136
1157
gtp_scores ,
1137
1158
patterns ,
1159
+ user_callback_per_generation = user_callback_per_generation ,
1160
+ user_callback_per_run = user_callback_per_run ,
1138
1161
)
1139
1162
new_best_patterns , coverage_counts , gtp_scores = res
1140
1163
patterns .update ({pat : run for pat , run in new_best_patterns })
0 commit comments