Skip to content

Commit eb64d52

Browse files
committed
make style
1 parent: 0bda5c5 · commit: eb64d52

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

scripts/convert_dcae_to_diffusers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def convert_ae(config_name: str, dtype: torch.dtype):
9999
hub_id = f"mit-han-lab/{config_name}"
100100
ckpt_path = hf_hub_download(hub_id, "model.safetensors")
101101
original_state_dict = get_state_dict(load_file(ckpt_path))
102-
102+
103103
ae = AutoencoderDC(**config).to(dtype=dtype)
104104

105105
for key in list(original_state_dict.keys()):
@@ -122,8 +122,22 @@ def get_ae_config(name: str):
122122
if name in ["dc-ae-f32c32-sana-1.0"]:
123123
config = {
124124
"latent_channels": 32,
125-
"encoder_block_types": ("ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"),
126-
"decoder_block_types": ("ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock", "EfficientViTBlock", "EfficientViTBlock"),
125+
"encoder_block_types": (
126+
"ResBlock",
127+
"ResBlock",
128+
"ResBlock",
129+
"EfficientViTBlock",
130+
"EfficientViTBlock",
131+
"EfficientViTBlock",
132+
),
133+
"decoder_block_types": (
134+
"ResBlock",
135+
"ResBlock",
136+
"ResBlock",
137+
"EfficientViTBlock",
138+
"EfficientViTBlock",
139+
"EfficientViTBlock",
140+
),
127141
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
128142
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
129143
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
8888
hidden_states = self.conv1(hidden_states)
8989
hidden_states = self.nonlinearity(hidden_states)
9090
hidden_states = self.conv2(hidden_states)
91-
91+
9292
if self.norm_type == "rms_norm":
9393
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
9494
else:
@@ -230,7 +230,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
230230
hidden_states = self.quadratic_attention(qkv)
231231

232232
hidden_states = self.proj_out(hidden_states)
233-
233+
234234
if self.norm_type == "rms_norm":
235235
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
236236
else:

0 commit comments

Comments (0)