@@ -22,7 +22,8 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    gelu_tanh_kernel = activation.gelu_tanh
+
 
 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
     query = attn.to_q(hidden_states)
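The block added at new lines 44-49 fetches a fused tanh-approximate GELU from the Hub only when the `kernels` package is importable and the `DIFFUSERS_ENABLE_HUB_KERNELS` flag is set. As a hedged aside (not part of this diff), the sketch below folds the availability check, the opt-in flag, and an eager fallback into one place; `make_gelu_tanh` is an illustrative helper name, and the Hub kernel is assumed to be callable on a tensor just like `nn.GELU`, which is what its later assignment to `self.act_mlp` implies.

```python
# Sketch only: resolve a tanh-approximate GELU, preferring the Hub kernel when
# the optional `kernels` package is installed and the caller has opted in.
import torch.nn as nn


def make_gelu_tanh(use_hub_kernel: bool):
    if use_hub_kernel:
        try:
            from kernels import get_kernel  # optional dependency

            activation = get_kernel("kernels-community/activation", revision="add_more_act")
            return activation.gelu_tanh  # assumed to act on a tensor, like nn.GELU
        except ImportError:
            pass  # package missing: fall back to the eager implementation below
    return nn.GELU(approximate="tanh")
```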
@@ -350,7 +357,11 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.act_mlp = nn.GELU(approximate="tanh")
+        else:
+            self.act_mlp = gelu_tanh_kernel
+
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         self.attn = FluxAttention(
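In the `__init__` hunk above, the eager branch keeps the standard tanh-approximate GELU and the opt-in branch substitutes the Hub kernel in the same role. For reference, a minimal sketch of the numerical baseline that branch computes (assuming, as the interchangeable assignment to `self.act_mlp` suggests, that the fused kernel is intended as a drop-in match):

```python
# Sketch only: the eager tanh-approximate GELU baseline the kernel path replaces.
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 4, 16)
eager = nn.GELU(approximate="tanh")(x)
functional = F.gelu(x, approximate="tanh")
print(torch.allclose(eager, functional))  # True: same tanh-approximate GELU
```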