4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -143,6 +143,7 @@
         [
             "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
+            "AttentionProvider",
             "AuraFlowTransformer2DModel",
             "AutoencoderDC",
             "AutoencoderKL",
@@ -212,6 +213,7 @@
             "UVit2DModel",
             "VQModel",
             "WanTransformer3DModel",
+            "attention_provider",
         ]
     )
     _import_structure["optimization"] = [
@@ -738,6 +740,7 @@
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
+            AttentionProvider,
             AuraFlowTransformer2DModel,
             AutoencoderDC,
             AutoencoderKL,
@@ -806,6 +809,7 @@
             UVit2DModel,
             VQModel,
             WanTransformer3DModel,
+            attention_provider,
         )
         from .optimization import (
             get_constant_schedule,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -26,6 +26,7 @@

 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["attention_dispatch"] = ["AttentionProvider", "attention_provider"]
     _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -106,6 +107,7 @@
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_dispatch import AttentionProvider, attention_provider
         from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,
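Both `__init__.py` hunks register the new names twice: once in `_import_structure`, which diffusers' lazy loader consumes so `import diffusers` stays cheap, and once under `TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT`, where the real import runs so static analyzers can resolve the symbols. A minimal, self-contained toy of the lazy half, assuming only the stdlib — the real `_LazyModule` (in `diffusers/utils/import_utils.py`) handles module specs and error cases this sketch omits:

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Toy stand-in for diffusers' _LazyModule."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {submodule: [names]} into {name: submodule}.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # Called only when normal lookup fails: import the owning submodule
        # on first access, then resolve the attribute from it.
        if attr not in self._name_to_module:
            raise AttributeError(attr)
        module = importlib.import_module(self._name_to_module[attr])
        return getattr(module, attr)


# Stand-in mapping; in this PR it would be
# {"attention_dispatch": ["AttentionProvider", "attention_provider"]}.
lazy = LazyModule("demo", {"collections": ["OrderedDict"]})
print(lazy.OrderedDict)  # `collections` is imported only at this point
```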
20 changes: 20 additions & 0 deletions src/diffusers/models/attention_dispatch.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
 import functools
 import inspect
 from enum import Enum
@@ -123,6 +124,8 @@ class AttentionProvider(str, Enum):
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
     _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
     _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+    # TODO: let's not add support for Sparge Attention now because it requires per-model tuning.
+    # We can look into supporting some form of autotuning in the future.
     # SPARGE = "sparge"

     # `xformers`
@@ -157,6 +160,23 @@ def list_providers(cls):
         return list(cls._providers.keys())


+@contextlib.contextmanager
+def attention_provider(provider: AttentionProvider = AttentionProvider.NATIVE):
+    """
+    Context manager to set the active attention provider.
+    """
+    if provider not in _AttentionProviderRegistry._providers:
+        raise ValueError(f"Provider {provider} is not registered.")
+
+    old_provider = _AttentionProviderRegistry._active_provider
+    _AttentionProviderRegistry._active_provider = provider
+
+    try:
+        yield
+    finally:
+        _AttentionProviderRegistry._active_provider = old_provider
+
+
 def attention_dispatch(
     query: torch.Tensor,
     key: torch.Tensor,
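For reviewers, a usage sketch of the new context manager. `AttentionProvider.NATIVE` comes from the default in the signature above, and `_SAGE_QK_INT8_PV_FP16_TRITON` from the enum hunk; whether a given backend is registered presumably depends on the corresponding kernel package being installed, and the pipeline call is illustrative:

```python
from diffusers import AttentionProvider, attention_provider

# Route any attention computed through `attention_dispatch` to a chosen
# backend; the prior provider is restored on exit, even if the block
# raises, thanks to the try/finally above.
with attention_provider(AttentionProvider.NATIVE):
    ...  # e.g. run a pipeline call here

# Nesting behaves as expected because each level saves the prior provider:
with attention_provider(AttentionProvider.NATIVE):
    with attention_provider(AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON):
        ...  # SageAttention path (assumes its kernels are installed/registered)
    ...  # back to NATIVE here
```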
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -401,7 +401,7 @@ def prepare_latents(
         video_condition = torch.cat(
             [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
         )
-        video_condition = video_condition.to(device=device, dtype=dtype)
+        video_condition = video_condition.to(device=device, dtype=self.vae.dtype)

         latents_mean = (
             torch.tensor(self.vae.config.latents_mean)
@@ -421,6 +421,7 @@
         latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
         latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)

+        latent_condition = latent_condition.to(dtype=dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std

         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
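The two changes in this file form one dtype round-trip: encode the conditioning video in the VAE's own dtype, then cast the resulting latents back to the pipeline dtype before normalization. A minimal, self-contained sketch of that pattern, with a stand-in `Conv3d` in place of the actual Wan VAE:

```python
import torch

vae = torch.nn.Conv3d(3, 4, kernel_size=1).to(torch.float32)  # VAE kept in fp32
pipeline_dtype = torch.bfloat16  # dtype the rest of the pipeline runs in

video_condition = torch.randn(1, 3, 16, 8, 8, dtype=pipeline_dtype)

# First change: feed the VAE inputs in its own dtype (self.vae.dtype), so its
# kernels don't error or silently cast on mismatched input dtypes...
latents = vae(video_condition.to(dtype=vae.weight.dtype))

# Second change: cast the latents back to the pipeline dtype before the
# (latents - mean) * std normalization, keeping downstream math in one dtype.
latents = latents.to(dtype=pipeline_dtype)
```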
19 changes: 19 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -85,6 +85,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class AttentionProvider(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AuraFlowTransformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]

@@ -1105,6 +1120,10 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+def attention_provider(*args, **kwargs):
+    requires_backends(attention_provider, ["torch"])
+
+
 def get_constant_schedule(*args, **kwargs):
     requires_backends(get_constant_schedule, ["torch"])

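For context on this file's pattern: when `torch` is absent, diffusers exposes these stubs in place of the real symbols, so the failure surfaces at first use with an actionable message rather than at import time. A self-contained toy of the mechanism, with simplified stand-ins for `DummyObject` and `requires_backends` (the real ones live in `diffusers.utils`):

```python
def requires_backends(obj, backends):
    name = obj if isinstance(obj, str) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {backends}")


class DummyObject(type):
    """Metaclass that turns class-attribute access into a helpful error."""

    def __getattr__(cls, key):
        requires_backends(cls.__name__, cls._backends)


class AttentionProvider(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


# Importing the stub is free; *using* it raises a clear, actionable error:
try:
    AttentionProvider()
except ImportError as err:
    print(err)  # AttentionProvider requires the following backends: ['torch']
```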