
Commit 50c0b78

start kernelize.

1 parent f5c113e

File tree

3 files changed: +46 −0 lines changed


src/diffusers/models/attention.py

Lines changed: 2 additions & 0 deletions

@@ -20,6 +20,7 @@
 
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0

@@ -1669,6 +1670,7 @@ def forward(
         return hidden_states
 
 
+@use_kernel_forward_from_hub("MLP")
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.

src/diffusers/models/normalization.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@
 import torch.nn.functional as F
 
 from ..utils import is_torch_npu_available, is_torch_version
+from ..utils.kernels_utils import use_kernel_forward_from_hub
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 

@@ -508,6 +509,7 @@ def forward(self, input):
         return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
 
 
+@use_kernel_forward_from_hub("RMSNorm")
 class RMSNorm(nn.Module):
     r"""
     RMS Norm as introduced in https://huggingface.co/papers/1910.07467 by Zhang et al.
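
Both additions follow the same pattern: the decorator only tags the class with a Hub kernel name ("MLP", "RMSNorm"); the eager forward stays in place unless a matching kernel is registered for the device in use. Below is a minimal sketch of the pattern on a throwaway module (ToyMLP and its sizes are illustrative, not part of this commit); it also runs when `kernels` is absent, because the stub decorator added in kernels_utils.py is a no-op.

import torch
from torch import nn
from torch.nn import functional as F

# Same decorator this commit applies to FeedForward ("MLP") and RMSNorm ("RMSNorm").
from diffusers.utils.kernels_utils import use_kernel_forward_from_hub


@use_kernel_forward_from_hub("MLP")  # tag the class under the "MLP" kernel name
class ToyMLP(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Default eager path; a registered Hub kernel may take over on supported devices.
        return self.fc2(F.gelu(self.fc1(x)))


x = torch.randn(2, 16, 64)
print(ToyMLP()(x).shape)  # torch.Size([2, 16, 64])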

src/diffusers/utils/kernels_utils.py

Lines changed: 42 additions & 0 deletions

@@ -1,3 +1,5 @@
+from typing import Union
+
 from ..utils import get_logger
 from .import_utils import is_kernels_available
 

@@ -21,3 +23,43 @@ def _get_fa3_from_hub():
     except Exception as e:
         logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
         raise
+
+
+if is_kernels_available():
+    from kernels import (
+        Device,
+        LayerRepository,
+        register_kernel_mapping,
+        replace_kernel_forward_from_hub,
+        use_kernel_forward_from_hub,
+    )
+
+    _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
+        "RMSNorm": {
+            "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
+        },
+        "MLP": {"cuda": LayerRepository(repo_id="medmekk/triton-llama-mlp", layer_name="TritonLlamaMLP")},
+    }
+
+    register_kernel_mapping(_KERNEL_MAPPING)
+
+else:
+    # Stub to make decorators in transformers work when `kernels`
+    # is not installed.
+    def use_kernel_forward_from_hub(*args, **kwargs):
+        def decorator(cls):
+            return cls
+
+        return decorator
+
+    class LayerRepository:
+        def __init__(self, *args, **kwargs):
+            raise RuntimeError("LayerRepository requires `kernels` to be installed. Run `pip install kernels`.")
+
+    def replace_kernel_forward_from_hub(*args, **kwargs):
+        raise RuntimeError(
+            "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
+        )
+
+    def register_kernel_mapping(*args, **kwargs):
+        raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")
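
The default mapping registered above can be extended from user code with the same APIs this module imports from `kernels`. A hedged sketch, assuming `kernels` is installed; the repo_id and layer_name below are hypothetical placeholders, not real Hub artifacts:

from kernels import LayerRepository, register_kernel_mapping

# Register a different CUDA kernel for any layer decorated with
# @use_kernel_forward_from_hub("RMSNorm"); both names below are placeholders.
register_kernel_mapping(
    {
        "RMSNorm": {
            "cuda": LayerRepository(
                repo_id="your-org/your-rms-norm-kernels",
                layer_name="FastRMSNorm",
            ),
        },
    }
)

Because the decorated layers only reference a kernel name, the choice of Hub repository stays decoupled from the model code.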
