Skip to content

Commit 352805d

Browse files
committed
fix for python 3.9
1 parent 4145e27 commit 352805d

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

ldm/models/diffusion/ddpm.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -614,30 +614,11 @@ def configure_optimizers(self):
614614
class LatentDiffusion(DDPM):
615615
"""main class"""
616616

617-
@staticmethod
618-
def _fallback_personalization_config()->dict:
619-
"""
620-
This protects us against custom legacy config files that
621-
don't contain the personalization_config section.
622-
"""
623-
return OmegaConf.create(
624-
dict(
625-
target='ldm.modules.embedding_manager.EmbeddingManager',
626-
params=dict(
627-
placeholder_strings=list('*'),
628-
initializer_words=list('sculpture'),
629-
per_image_tokens=False,
630-
num_vectors_per_token=1,
631-
progressive_words=False,
632-
)
633-
)
634-
)
635-
636617
def __init__(
637618
self,
638619
first_stage_config,
639620
cond_stage_config,
640-
personalization_config=_fallback_personalization_config(),
621+
personalization_config=None,
641622
num_timesteps_cond=None,
642623
cond_stage_key='image',
643624
cond_stage_trainable=False,
@@ -695,7 +676,8 @@ def __init__(
695676
self.model.train = disabled_train
696677
for param in self.model.parameters():
697678
param.requires_grad = False
698-
679+
680+
personalization_config = personalization_config or self._fallback_personalization_config()
699681
self.embedding_manager = self.instantiate_embedding_manager(
700682
personalization_config, self.cond_stage_model
701683
)
@@ -2170,6 +2152,25 @@ def on_save_checkpoint(self, checkpoint):
21702152

21712153
self.emb_ckpt_counter += 500
21722154

2155+
@classmethod
2156+
def _fallback_personalization_config(self)->dict:
2157+
"""
2158+
This protects us against custom legacy config files that
2159+
don't contain the personalization_config section.
2160+
"""
2161+
return OmegaConf.create(
2162+
dict(
2163+
target='ldm.modules.embedding_manager.EmbeddingManager',
2164+
params=dict(
2165+
placeholder_strings=list('*'),
2166+
initializer_words=list('sculpture'),
2167+
per_image_tokens=False,
2168+
num_vectors_per_token=1,
2169+
progressive_words=False,
2170+
)
2171+
)
2172+
)
2173+
21732174

21742175
class DiffusionWrapper(pl.LightningModule):
21752176
def __init__(self, diff_model_config, conditioning_key):

0 commit comments

Comments
 (0)