Skip to content

Commit 1764ac3

Browse files
committed
use hash to check valid optim
1 parent 0b143c1 commit 1764ac3

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,13 @@ def save(self, filename):
177177
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
178178
if self.optimizer_name is not None:
179179
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
180+
181+
torch.save(state_dict, filename)
180182
if self.optimizer_state_dict:
183+
optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
181184
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
182185
torch.save(optimizer_saved_dict, filename + '.optim')
183186

184-
torch.save(state_dict, filename)
185-
186187
def load(self, filename):
187188
self.filename = filename
188189
if self.name is None:
@@ -204,7 +205,10 @@ def load(self, filename):
204205
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
205206
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
206207
print(f"Optimizer name is {self.optimizer_name}")
207-
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
208+
if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
209+
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
210+
else:
211+
self.optimizer_state_dict = None
208212
if self.optimizer_state_dict:
209213
print("Loaded existing optimizer from checkpoint")
210214
else:
@@ -229,7 +233,7 @@ def list_hypernetworks(path):
229233
name = os.path.splitext(os.path.basename(filename))[0]
230234
# Prevent a hypothetical "None.pt" from being listed.
231235
if name != "None":
232-
res[name] = filename
236+
res[name + f"({sd_models.model_hash(filename)})"] = filename
233237
return res
234238

235239

@@ -375,6 +379,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
375379
else:
376380
hypernetwork_dir = None
377381

382+
hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
378383
if create_image_every > 0:
379384
images_dir = os.path.join(log_directory, "images")
380385
os.makedirs(images_dir, exist_ok=True)

0 commit comments

Comments
 (0)