Skip to content

Commit 59de0a3

Browse files
committed
remove autocast and some assert
1 parent c6eb233 commit 59de0a3

File tree

1 file changed

+17
-15
lines changed
  • src/diffusers/models/autoencoders

1 file changed

+17
-15
lines changed

src/diffusers/models/autoencoders/dc_ae.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __init__(
211211
scales: tuple[int, ...] = (5,),
212212
eps=1.0e-15,
213213
):
214-
super(LiteMLA, self).__init__()
214+
super().__init__()
215215
self.eps = eps
216216
heads = int(in_channels // dim * heads_ratio) if heads is None else heads
217217

@@ -253,7 +253,6 @@ def __init__(
253253
act_func=act_func[1],
254254
)
255255

256-
@torch.autocast(device_type="cuda", enabled=False)
257256
def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
258257
B, _, H, W = list(qkv.size())
259258

@@ -292,7 +291,6 @@ def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
292291
out = torch.reshape(out, (B, -1, H, W))
293292
return out
294293

295-
@torch.autocast(device_type="cuda", enabled=False)
296294
def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
297295
B, _, H, W = list(qkv.size())
298296

@@ -657,11 +655,12 @@ def __init__(
657655
super().__init__()
658656
num_stages = len(width_list)
659657
self.num_stages = num_stages
660-
assert len(depth_list) == num_stages
661-
assert len(width_list) == num_stages
662-
assert isinstance(block_type, str) or (
663-
isinstance(block_type, list) and len(block_type) == num_stages
664-
)
658+
659+
# validate config
660+
if len(depth_list) != num_stages or len(width_list) != num_stages:
661+
raise ValueError(f"len(depth_list) {len(depth_list)} and len(width_list) {len(width_list)} should be equal to num_stages {num_stages}")
662+
if not isinstance(block_type, (str, list)) or (isinstance(block_type, list) and len(block_type) != num_stages):
663+
raise ValueError(f"block_type should be either a str or a list of str with length {num_stages}, but got {block_type}")
665664

666665
self.project_in = build_encoder_project_in_block(
667666
in_channels=in_channels,
@@ -725,13 +724,16 @@ def __init__(
725724
super().__init__()
726725
num_stages = len(width_list)
727726
self.num_stages = num_stages
728-
assert len(depth_list) == num_stages
729-
assert len(width_list) == num_stages
730-
assert isinstance(block_type, str) or (
731-
isinstance(block_type, list) and len(block_type) == num_stages
732-
)
733-
assert isinstance(norm, str) or (isinstance(norm, list) and len(norm) == num_stages)
734-
assert isinstance(act, str) or (isinstance(act, list) and len(act) == num_stages)
727+
728+
# validate config
729+
if len(depth_list) != num_stages or len(width_list) != num_stages:
730+
raise ValueError(f"len(depth_list) {len(depth_list)} and len(width_list) {len(width_list)} should be equal to num_stages {num_stages}")
731+
if not isinstance(block_type, (str, list)) or (isinstance(block_type, list) and len(block_type) != num_stages):
732+
raise ValueError(f"block_type should be either a str or a list of str with length {num_stages}, but got {block_type}")
733+
if not isinstance(norm, (str, list)) or (isinstance(norm, list) and len(norm) != num_stages):
734+
raise ValueError(f"norm should be either a str or a list of str with length {num_stages}, but got {norm}")
735+
if not isinstance(act, (str, list)) or (isinstance(act, list) and len(act) != num_stages):
736+
raise ValueError(f"act should be either a str or a list of str with length {num_stages}, but got {act}")
735737

736738
self.project_in = build_decoder_project_in_block(
737739
in_channels=latent_channels,

0 commit comments

Comments
 (0)