@@ -147,6 +147,9 @@ def train(self, patience, num_epochs):
147147 # everything is set up
148148 for e in range (num_epochs ):
149149 # simulate training
150+
151+ # calling reschedule hook
152+ reschedule_hook (model_weights = {}, step = e )
150153 continue
151154 results = {
152155 "test_acc" : 0.5 + 0.3 * np .random .randn (),
@@ -165,6 +168,22 @@ def get_experiment(init_all=False):
165168 return experiment
166169
167170
# Reschedule hook: invoked by the framework when a reschedule is triggered.
@ex.reschedule_hook
def reschedule_hook(model_weights, step, **kwargs):
    """Save the current experiment state and return config updates for rescheduling.

    Call this function regularly from within the training loop so the
    framework can check whether rescheduling is needed; pass everything
    required to restore your state. You are responsible for implementing
    the actual saving/loading of the experiment state according to the
    updated config.

    Args:
        model_weights: Current model weights to persist.
            # NOTE(review): unused placeholder in this stub — confirm expected type.
        step: Current training step at which the hook is called.
        **kwargs: Any additional state needed to restore the experiment.

    Returns:
        dict: Configuration updates to apply upon rescheduling — here the
        path of the saved checkpoint to resume from.
    """
    # Placeholder implementation: a real hook would persist model_weights
    # and step here before returning the resume information.
    return {"checkpoint_path": "path/to/saved/checkpoint"}
186+
168187# This function will be called by default. Note that we could in principle manually pass an experiment instance,
169188# e.g., obtained by loading a model from the database or by calling this from a Jupyter notebook.
170189@ex .automain
0 commit comments