
Commit 9a09162

Commit message: up
1 parent 33a8a3b commit 9a09162

1 file changed (+13, -2)

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 13 additions & 2 deletions

@@ -22,7 +22,8 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_kernels_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -40,6 +41,12 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    gelu_tanh_kernel = activation.gelu_tanh
+
 
 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
     query = attn.to_q(hidden_states)
@@ -350,7 +357,11 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
 
         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        self.act_mlp = nn.GELU(approximate="tanh")
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.act_mlp = nn.GELU(approximate="tanh")
+        else:
+            self.act_mlp = gelu_tanh_kernel
+
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         self.attn = FluxAttention(
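For readers who want to try the same gating pattern outside this file, here is a minimal, self-contained sketch. It is not part of the commit: the environment-variable name and accepted truthy values, and the assumption that activation.gelu_tanh can be called on a tensor like an nn.Module (as the diff implies by assigning it directly to self.act_mlp), are inferred from the diff rather than confirmed.

import os

import torch
import torch.nn as nn


def pick_gelu_tanh():
    """Return a gelu-tanh activation, preferring a Hub kernel when opted in.

    Mirrors the gating in the diff: torch's nn.GELU(approximate="tanh") is the
    default, and the Hub kernel is used only when the flag is set and the
    `kernels` package can load it.
    """
    # Assumption: the flag is driven by an env var with the same name as the
    # constant and accepts common truthy values; diffusers' parsing may differ.
    use_hub_kernels = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "").lower() in {"1", "true", "yes"}
    if use_hub_kernels:
        try:
            from kernels import get_kernel  # hub-kernel loader from the `kernels` package

            activation = get_kernel("kernels-community/activation", revision="add_more_act")
            # Assumption: activation.gelu_tanh is directly callable on a tensor.
            return activation.gelu_tanh
        except Exception:
            pass  # package missing or kernel fetch failed: fall back below
    return nn.GELU(approximate="tanh")


act_mlp = pick_gelu_tanh()
hidden_states = torch.randn(2, 16)
print(act_mlp(hidden_states).shape)  # torch.Size([2, 16]) on the fallback path

Keeping nn.GELU(approximate="tanh") as the default means behavior is unchanged for anyone who has not installed the kernels package or enabled the flag; the Hub kernel is strictly opt-in.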
