1+ from typing import Union
2+
13from ..utils import get_logger
24from .import_utils import is_kernels_available
35
@@ -21,3 +23,43 @@ def _get_fa3_from_hub():
2123 except Exception as e :
2224 logger .error (f"An error occurred while fetching kernel '{ _DEFAULT_HUB_ID_FA3 } ' from the Hub: { e } " )
2325 raise
26+
27+
28+ if is_kernels_available ():
29+ from kernels import (
30+ Device ,
31+ LayerRepository ,
32+ register_kernel_mapping ,
33+ replace_kernel_forward_from_hub ,
34+ use_kernel_forward_from_hub ,
35+ )
36+
37+ _KERNEL_MAPPING : dict [str , dict [Union [Device , str ], LayerRepository ]] = {
38+ "RMSNorm" : {
39+ "cuda" : LayerRepository (repo_id = "kernels-community/liger_kernels" , layer_name = "LigerRMSNorm" ),
40+ },
41+ "MLP" : {"cuda" : LayerRepository (repo_id = "medmekk/triton-llama-mlp" , layer_name = "TritonLlamaMLP" )},
42+ }
43+
44+ register_kernel_mapping (_KERNEL_MAPPING )
45+
46+ else :
47+ # Stub to make decorators int transformers work when `kernels`
48+ # is not installed.
49+ def use_kernel_forward_from_hub (* args , ** kwargs ):
50+ def decorator (cls ):
51+ return cls
52+
53+ return decorator
54+
55+ class LayerRepository :
56+ def __init__ (self , * args , ** kwargs ):
57+ raise RuntimeError ("LayerRepository requires `kernels` to be installed. Run `pip install kernels`." )
58+
59+ def replace_kernel_forward_from_hub (* args , ** kwargs ):
60+ raise RuntimeError (
61+ "replace_kernel_forward_from_hub requires `kernels` to be installed. Run `pip install kernels`."
62+ )
63+
64+ def register_kernel_mapping (* args , ** kwargs ):
65+ raise RuntimeError ("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`." )
0 commit comments