Skip to content

Commit 142ed1a

Browse files
committed
add a save and load function that acts on the wired stable diffusion, and make it clear to beginner where the 100kb of data per concept is stored
1 parent 183cd32 commit 142ed1a

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed

perfusion_pytorch/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@
1010
EmbeddingWrapper,
1111
merge_embedding_wrappers
1212
)
13+
14+
from perfusion_pytorch.save_load import (
15+
save,
16+
load
17+
)

perfusion_pytorch/save_load.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from pathlib import Path
2+
3+
import torch
4+
from torch import nn
5+
from torch.nn import Module
6+
7+
from beartype import beartype
8+
9+
from perfusion_pytorch.embedding import EmbeddingWrapper
10+
from perfusion_pytorch.perfusion import Rank1EditModule
11+
12+
def exists(val):
13+
return val is not None
14+
15+
@beartype
16+
def save(
17+
text_image_model: Module,
18+
path: str
19+
):
20+
path = Path(path)
21+
path.parents[0].mkdir(exist_ok = True, parents = True)
22+
23+
embed_params = None
24+
key_value_params = []
25+
C_inv = None
26+
27+
for module in text_image_model.modules():
28+
if isinstance(module, EmbeddingWrapper):
29+
assert not exists(embed_params), 'there should only be one wrapped EmbeddingWrapper'
30+
embed_params = module.concepts.data
31+
32+
elif isinstance(module, Rank1EditModule):
33+
key_value_params.append([
34+
module.ema_concept_text_encs.data,
35+
module.concept_outputs.data
36+
])
37+
38+
C_inv = module.C_inv.data
39+
40+
assert exists(C_inv), 'Rank1EditModule not found. you likely did not wire up the text to image model correctly'
41+
42+
pkg = dict(
43+
embed_params = embed_params,
44+
key_value_params = key_value_params,
45+
C_inv = C_inv
46+
)
47+
48+
torch.save(pkg, f'{str(path)}')
49+
print(f'saved to {str(path)}')
50+
51+
@beartype
52+
def load(
53+
text_image_model: Module,
54+
path: str
55+
):
56+
path = Path(path)
57+
assert path.exists(), f'file not found at {str(path)}'
58+
59+
pkg = torch.load(str(path))
60+
61+
embed_params = pkg['embed_params']
62+
key_value_params = pkg['key_value_params']
63+
C_inv = pkg['C_inv']
64+
65+
for module in text_image_model.modules():
66+
if isinstance(module, EmbeddingWrapper):
67+
module.concepts.data.copy_(embed_params)
68+
69+
elif isinstance(module, Rank1EditModule):
70+
assert len(key_value_params) > 0, 'mismatch between what was saved vs what is being loaded'
71+
concept_input, concept_output = key_value_params.pop(0)
72+
module.ema_concept_text_encs.data.copy_(concept_input)
73+
module.concept_outputs.data.copy_(concept_output)
74+
75+
module.C_inv.copy_(C_inv)
76+
module.initted.copy_(torch.tensor([True]))
77+
78+
print(f'loaded concept params from {str(path)}')

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.7',
6+
version = '0.1.8',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)