Commit 865fe1b

Commit message: up
1 parent ff1ab1a commit 865fe1b

src/diffusers/models/autoencoders/autoencoder_kl_flux2.py

1 file changed: 5 additions & 63 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -22,11 +22,11 @@
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
+from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
     Attention,
-    AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
     FusedAttnProcessor2_0,
@@ -36,7 +36,9 @@
 from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKLFlux2(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class AutoencoderKLFlux2(
+    ModelMixin, AutoencoderMixin, AttentionMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin
+):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -154,66 +156,6 @@ def __init__(
         self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
         self.tile_overlap_factor = 0.25
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
