Skip to content

Commit 0b143c1

Browse files
committed
Separate .optim file from model
1 parent 7ea5956 commit 0b143c1

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

modules/hypernetworks/hypernetwork.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def weights(self):
161161

162162
def save(self, filename):
163163
state_dict = {}
164+
optimizer_saved_dict = {}
164165

165166
for k, v in self.layers.items():
166167
state_dict[k] = (v[0].state_dict(), v[1].state_dict())
@@ -175,9 +176,10 @@ def save(self, filename):
175176
state_dict['sd_checkpoint'] = self.sd_checkpoint
176177
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
177178
if self.optimizer_name is not None:
178-
state_dict['optimizer_name'] = self.optimizer_name
179+
optimizer_saved_dict['optimizer_name'] = self.optimizer_name
179180
if self.optimizer_state_dict:
180-
state_dict['optimizer_state_dict'] = self.optimizer_state_dict
181+
optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
182+
torch.save(optimizer_saved_dict, filename + '.optim')
181183

182184
torch.save(state_dict, filename)
183185

@@ -198,9 +200,11 @@ def load(self, filename):
198200
print(f"Layer norm is set to {self.add_layer_norm}")
199201
self.use_dropout = state_dict.get('use_dropout', False)
200202
print(f"Dropout usage is set to {self.use_dropout}")
201-
self.optimizer_name = state_dict.get('optimizer_name', 'AdamW')
203+
204+
optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {}
205+
self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
202206
print(f"Optimizer name is {self.optimizer_name}")
203-
self.optimizer_state_dict = state_dict.get('optimizer_state_dict', None)
207+
self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
204208
if self.optimizer_state_dict:
205209
print("Loaded existing optimizer from checkpoint")
206210
else:

0 commit comments

Comments
 (0)