Skip to content

Commit bdb01c0

Browse files
committed
extended gp learner with args for user callbacks per generation and run
1 parent 87ddb07 commit bdb01c0

File tree

1 file changed

+32
-9
lines changed

1 file changed

+32
-9
lines changed

gp_learner.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -911,15 +911,19 @@ def generate_init_population(
911911
return fit_population
912912

913913

914-
def generation_step_callback(toolbox, ngen, population):
914+
def generation_step_callback(
915+
run, gtp_scores, user_callback_per_generation, ngen, population
916+
):
915917
"""Called after each generation step cycle in train().
916918
917-
:param toolbox: toolbox of the evolutionary algorithm.
919+
:param run: number of the current run
920+
:param gtp_scores: gtp_scores as of start of this run
921+
:param user_callback_per_generation: a user provided callback that is called
922+
after each training generation. It not None called like this:
923+
user_callback_per_generation(run, gtp_scores, ngen, population)
918924
:param ngen: the number of the current generation.
919925
:param population: the current population after generation ngen.
920926
"""
921-
run = toolbox.get_run()
922-
gtp_scores = toolbox.get_gtp_scores()
923927
top_counter = print_population(run, ngen, population)
924928
top_gps = sorted(
925929
top_counter.keys(), key=attrgetter("fitness"), reverse=True
@@ -929,15 +933,18 @@ def generation_step_callback(toolbox, ngen, population):
929933
save_population(
930934
run, ngen, top_gps, generation_gtp_scores
931935
)
936+
if user_callback_per_generation:
937+
# user provided callback
938+
user_callback_per_generation(run, gtp_scores, ngen, population)
932939

933940

934941
def find_graph_patterns(
935-
sparql, run, gtp_scores):
942+
sparql, run, gtp_scores,
943+
user_callback_per_generation=None,
944+
):
936945
timeout = calibrate_query_timeout(sparql)
937946

938947
toolbox = deap.base.Toolbox()
939-
toolbox.register("get_run", lambda: run)
940-
toolbox.register("get_gtp_scores", lambda: gtp_scores)
941948

942949
toolbox.register(
943950
"mate", mate
@@ -952,7 +959,10 @@ def find_graph_patterns(
952959
)
953960
toolbox.register(
954961
"evaluate", evaluate, sparql, timeout, gtp_scores)
955-
toolbox.register("generation_step_callback", generation_step_callback)
962+
toolbox.register(
963+
"generation_step_callback",
964+
generation_step_callback, run, gtp_scores, user_callback_per_generation
965+
)
956966

957967

958968
population = generate_init_population(
@@ -985,11 +995,15 @@ def _find_graph_pattern_coverage_run(
985995
coverage_counts,
986996
gtp_scores,
987997
patterns,
998+
user_callback_per_generation=None,
999+
user_callback_per_run=None,
9881000
):
9891001
min_fitness = calc_min_fitness(gtp_scores, min_score)
9901002

9911003
ngen, res_pop, hall_of_fame, toolbox = find_graph_patterns(
992-
sparql, run, gtp_scores)
1004+
sparql, run, gtp_scores,
1005+
user_callback_per_generation=user_callback_per_generation,
1006+
)
9931007

9941008
# TODO: coverage patterns should be chosen based on similarity
9951009
new_best_patterns = []
@@ -1085,6 +1099,11 @@ def _find_graph_pattern_coverage_run(
10851099
)
10861100
set_symlink(fp, config.SYMLINK_CURRENT_RES_RUN)
10871101

1102+
if user_callback_per_run:
1103+
user_callback_per_run(
1104+
run, gtp_scores, new_best_patterns, coverage_counts
1105+
)
1106+
10881107
return new_best_patterns, coverage_counts, gtp_scores
10891108

10901109

@@ -1096,6 +1115,8 @@ def find_graph_pattern_coverage(
10961115
max_runs=config.NRUNS,
10971116
runs_no_improvement=config.NRUNS_NO_IMPROVEMENT,
10981117
error_retries=config.ERROR_RETRIES,
1118+
user_callback_per_generation=None,
1119+
user_callback_per_run=None,
10991120
):
11001121
assert isinstance(ground_truth_pairs, tuple)
11011122

@@ -1135,6 +1156,8 @@ def find_graph_pattern_coverage(
11351156
coverage_counts,
11361157
gtp_scores,
11371158
patterns,
1159+
user_callback_per_generation=user_callback_per_generation,
1160+
user_callback_per_run=user_callback_per_run,
11381161
)
11391162
new_best_patterns, coverage_counts, gtp_scores = res
11401163
patterns.update({pat: run for pat, run in new_best_patterns})

0 commit comments

Comments
 (0)