Skip to content

Commit da7b8ce

Browse files
authored
[kernels] Kernel Config (#41232)
* first config * add kernel_config * add import logic * fixing style * compare class name * add comments * rm import * adding kernel md files * add to toctree * adding to main_classes * simplify required config * add to doc * style * store the mapping * remove nested func * add hub mixin * fix * imports * fix
1 parent 4763b8c commit da7b8ce

File tree

7 files changed

+266
-4
lines changed

7 files changed

+266
-4
lines changed

docs/source/en/_toctree.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@
216216
- local: quantization/contribute
217217
title: Contribute
218218
title: Quantization
219+
- isExpanded: false
220+
sections:
221+
- local: kernel_doc/overview
222+
title: Kernels in transformers
223+
title: Kernels
219224
- isExpanded: false
220225
sections:
221226
- local: serialization
@@ -368,6 +373,8 @@
368373
title: Image Processor
369374
- local: main_classes/video_processor
370375
title: Video Processor
376+
- local: main_classes/kernels
377+
title: Kernels
371378
title: Main Classes
372379
- sections:
373380
- sections:
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Overview
2+
3+
Kernels in Transformers are used to optimize the performance of models by pulling custom layers from the Hub with very low effort.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## Kernels
2+
3+
This page documents the kernels configuration utilities.
4+
5+
### KernelConfig
6+
7+
[[autodoc]] KernelConfig

src/transformers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@
265265
"VptqConfig",
266266
],
267267
"video_utils": [],
268+
"utils.kernel_config": ["KernelConfig"],
268269
}
269270

270271
# tokenizers-backed objects
@@ -754,6 +755,7 @@
754755
from .utils import is_torch_npu_available as is_torch_npu_available
755756
from .utils import is_torch_xla_available as is_torch_xla_available
756757
from .utils import is_torch_xpu_available as is_torch_xpu_available
758+
from .utils.kernel_config import KernelConfig as KernelConfig
757759

758760
# bitsandbytes config
759761
from .utils.quantization_config import AqlmConfig as AqlmConfig
@@ -775,7 +777,6 @@
775777
from .utils.quantization_config import TorchAoConfig as TorchAoConfig
776778
from .utils.quantization_config import VptqConfig as VptqConfig
777779
from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
778-
779780
else:
780781
import sys
781782

src/transformers/modeling_utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
WEIGHTS_INDEX_NAME,
8585
WEIGHTS_NAME,
8686
ContextManagers,
87+
KernelConfig,
8788
PushToHubMixin,
8889
cached_file,
8990
check_torch_load_is_safe,
@@ -4503,6 +4504,7 @@ def from_pretrained(
45034504
device_mesh = kwargs.pop("device_mesh", None)
45044505
trust_remote_code = kwargs.pop("trust_remote_code", None)
45054506
use_kernels = kwargs.pop("use_kernels", False)
4507+
kernel_config = kwargs.pop("kernel_config", None)
45064508

45074509
key_mapping = kwargs.pop("key_mapping", None)
45084510
# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
@@ -4895,7 +4897,26 @@ def _assign_original_dtype(module):
48954897

48964898
# check if using kernels
48974899
if use_kernels:
4898-
model.use_kernels = True
4900+
if not is_kernels_available():
4901+
raise ValueError(
4902+
"Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
4903+
)
4904+
from kernels import use_kernel_mapping
4905+
4906+
if kernel_config is not None and isinstance(kernel_config, KernelConfig):
4907+
# This will make sure the mapping is valid, and the layers are registered in the model
4908+
kernel_config.sanitize_kernel_mapping(model)
4909+
4910+
# This will create a compatible mapping for the model with the kernels library
4911+
kernel_config.create_compatible_mapping(model)
4912+
4913+
# This is a context manager to override the default kernel mapping
4914+
# We are calling kernelize inside this context manager using the use_kernels setter
4915+
with use_kernel_mapping(kernel_config.kernel_mapping):
4916+
model.use_kernels = True
4917+
# We use the default kernel mapping in .integrations.hub_kernels
4918+
else:
4919+
model.use_kernels = True
48994920

49004921
# If it is a model with generation capabilities, attempt to load generation files (generation config,
49014922
# custom generate function)
@@ -5506,14 +5527,14 @@ def loss_function(self):
55065527
def loss_function(self, value):
55075528
self._loss_function = value
55085529

5509-
def kernelize(self):
5530+
def kernelize(self, mode=None):
    """
    Kernelize this model with the `kernels` library.

    Args:
        mode: An optional `kernels.Mode`. When `None`, the mode is inferred
            from the model's training state (`Mode.TRAINING` if `self.training`
            else `Mode.INFERENCE`).

    Raises:
        ValueError: If the `kernels` package is not installed.
    """
    if not is_kernels_available():
        raise ValueError(
            "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
        )
    from kernels import Device, Mode, kernelize

    # Bug fix: the original one-liner
    #   mode = Mode.INFERENCE if not self.training else Mode.TRAINING if mode is None else mode
    # parses as `Mode.INFERENCE if not self.training else (Mode.TRAINING if mode is None else mode)`,
    # so an explicitly passed `mode` was silently discarded whenever the model was in eval mode.
    # Only infer the mode when the caller did not provide one.
    if mode is None:
        mode = Mode.TRAINING if self.training else Mode.INFERENCE
    kernelize(self, device=Device(type=self.device.type), mode=mode)
    self._use_kernels = True
55195540

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@
252252
requires_backends,
253253
torch_only_method,
254254
)
255+
from .kernel_config import KernelConfig
255256
from .peft_utils import (
256257
ADAPTER_CONFIG_NAME,
257258
ADAPTER_SAFE_WEIGHTS_NAME,
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ..utils import PushToHubMixin, is_kernels_available, is_torch_available
16+
17+
18+
if is_kernels_available():
19+
from kernels import LayerRepository, Mode
20+
21+
if is_torch_available():
22+
import torch
23+
24+
25+
def infer_device(model):
    """
    Infer the device type from the model parameters.

    Args:
        model: The model instance (anything exposing `.parameters()`).

    Returns:
        The device type string of the first parameter (e.g. `"cuda"`, `"xpu"`,
        `"cpu"`). A `"cuda"` device is refined to `"rocm"` when torch is a
        HIP/ROCm build, since torch reports ROCm devices as `"cuda"`.

    Raises:
        ValueError: If the model has no parameters to read a device from.
    """
    EXAMPLE_MAPPING = """
    {
        "RMSNorm": {
            "cuda":
                "kernels-community/layer_norm:LlamaRMSNorm",
            ...
        },
        ...
    }
    """
    try:
        param = next(model.parameters())
    except StopIteration:
        # `from None`: the StopIteration is an implementation detail of the
        # parameter iterator, not a cause callers need to see in the traceback.
        raise ValueError(
            f"Cannot determine model device, please provide a device to the mapping. Example: {EXAMPLE_MAPPING}"
        ) from None

    dev_type = param.device.type
    if dev_type == "cuda" and torch.version.hip is not None:
        # Refine based on the actual platform: ROCm builds report "cuda".
        return "rocm"
    return dev_type
58+
59+
60+
def add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping):
    """
    Insert a `LayerRepository` entry for (`layer_name`, `device`) into `compatible_mapping`.

    Args:
        layer_name: The registered kernel layer name (e.g. "RMSNorm").
        device: Target device type; must be one of "cuda", "rocm", "xpu".
        repo_name: Hub reference in the form "org/repo:layer_name".
        mode: The `kernels.Mode` to register the repository under.
        compatible_mapping: The mapping being built; mutated in place.

    Raises:
        ValueError: If `device` is not one of the supported device types.
    """
    if device not in ["cuda", "rocm", "xpu"]:
        raise ValueError(f"Only cuda, rocm, and xpu devices supported, got: {device}")
    # Split "org/repo:layer_name" once, instead of calling split(":") twice.
    repo_id, _, repo_layer_name = repo_name.partition(":")
    # Bug fix: the original did `compatible_mapping[layer_name] = {device: ...}`,
    # which replaced the whole layer entry on every call — a kernel declared for
    # several devices kept only the last one. Merge into the existing entry instead.
    compatible_mapping.setdefault(layer_name, {})[device] = {
        mode: LayerRepository(
            repo_id=repo_id,
            layer_name=repo_layer_name,
        )
    }
73+
74+
75+
class KernelConfig(PushToHubMixin):
    """
    Kernel configuration class. This class is used to configure the kernel mapping for a model.
    """

    def __init__(self, kernel_mapping=None):
        # Bug fix: the original signature used a mutable default (`kernel_mapping={}`),
        # which is created once and shared across every instance constructed without
        # an explicit mapping — mutating one config mutated them all. A `None`
        # sentinel keeps the call signature backward-compatible while giving each
        # instance its own dict.
        self.kernel_mapping = {} if kernel_mapping is None else kernel_mapping
        # Maps module path -> kernel_layer_name for modules registered via
        # register_kernel_forward_from_hub; filled by `store_registered_layer_names`.
        self.registered_layer_names = {}

    def update_kernel(self, repo_id, registered_name, layer_name, device, mode, revision=None):
        """Register (or overwrite) the kernel used for `registered_name` on `device` in `mode`."""
        self.kernel_mapping[registered_name] = {
            device: {
                mode: LayerRepository(
                    repo_id=repo_id,
                    layer_name=layer_name,
                    revision=revision,
                )
            }
        }

    def store_registered_layer_names(self, model):
        """Collect the `kernel_layer_name` of every registered module in `model`."""
        for name, module in model.named_modules():
            if hasattr(module, "kernel_layer_name"):
                self.registered_layer_names[name] = module.kernel_layer_name

    def sanitize_kernel_mapping(self, model):
        """
        Validates the kernel_mapping to ensure that:
        1. Each layer_name in the mapping is registered in the model (i.e., the model contains a module with a matching kernel_layer_name).
        2. Each kernel value is either a string of the form 'org/repo:layer_name' or a dict mapping device types ("cuda", "rocm", "xpu") to such strings.
        3. Each device key in a dict is one of "cuda", "rocm", or "xpu".
        4. Each repo_name is a valid repository and layer name in the format 'org/repo:layer_name' (i.e., a string containing both a slash and a colon).

        Args:
            model: The model instance whose modules are checked for registered kernel_layer_name attributes.

        Raises:
            ValueError: If a layer_name is not registered in the model, if a device is not supported,
                or if a repo_name is not a valid 'org/repo:layer_name' string.
        """
        MAPPING_FORMAT = """
    {
        "RMSNorm":
            "kernels-community/layer_norm:LlamaRMSNorm",
        ...
    },

    or

    {
        "RMSNorm": {
            "cuda":
                "kernels-community/layer_norm:LlamaRMSNorm",
            "rocm":
                "kernels-community/layer_norm:LlamaRMSNorm",
            ...
        },
        ...
    }
    """
        self.store_registered_layer_names(model)
        # Validate that the kernel mapping is a dict
        if not isinstance(self.kernel_mapping, dict):
            raise ValueError(
                f"Kernel mapping must be a dict of the following format: {MAPPING_FORMAT}, got: {type(self.kernel_mapping)}"
            )

        for layer_name, kernel in self.kernel_mapping.items():
            if layer_name not in self.registered_layer_names.values():
                raise ValueError(
                    f"Layer {layer_name} is not registered in the model, please register it first using register_kernel_forward_from_hub"
                )

            if isinstance(kernel, str):
                if "/" not in kernel or ":" not in kernel:
                    raise ValueError(
                        f"Kernel mapping for '{layer_name}' must be a valid repo name with a layer name (e.g., 'org/repo:layer_name'), got: {kernel}"
                    )

            elif isinstance(kernel, dict):
                for device, repo_name in kernel.items():
                    if device not in ["cuda", "rocm", "xpu"]:
                        raise ValueError(f"Only cuda, rocm, and xpu devices supported, got: {device}")

                    if not isinstance(repo_name, str) or "/" not in repo_name or ":" not in repo_name:
                        raise ValueError(
                            f"Kernel mapping for '{layer_name}' must be a valid repo name with a layer name (e.g., 'org/repo:layer_name'), got: {repo_name}"
                        )

            else:
                raise ValueError(f"Kernel mapping must follow the format: {MAPPING_FORMAT}, got: {kernel}")

    def create_compatible_mapping(self, model, compile=False):
        """
        Transforms a simple kernel_mapping of the form:
        {
            "RMSNorm":
                "kernels-community/layer_norm:LlamaRMSNorm",
            ...
        },

        or

        {
            "RMSNorm": {
                "cuda":
                    "kernels-community/layer_norm:LlamaRMSNorm",
                "rocm":
                    "kernels-community/layer_norm:LlamaRMSNorm",
                ...
            },
            ...
        }

        into a nested mapping:

        {
            "RMSNorm": {
                "cuda": {
                    Mode.INFERENCE: LayerRepository(
                        repo_id="kernels-community/layer_norm",
                        layer_name="LlamaRMSNorm",
                    )
                }
            }
        }

        that's compatible with the kernels library.

        The device is inferred from the model's parameters if not provided.
        The Mode is inferred from the model's training state.
        """
        # The mode is the same for every layer — it depends only on the model's
        # training state and the `compile` flag — so compute it once instead of
        # per loop iteration as the original did.
        mode = Mode.TRAINING if model.training else Mode.INFERENCE
        if compile:
            mode = mode | Mode.TORCH_COMPILE

        compatible_mapping = {}
        for layer_name, kernel in self.kernel_mapping.items():
            if isinstance(kernel, str):
                # Single repo string: infer the target device from the model parameters.
                add_to_mapping(layer_name, infer_device(model), kernel, mode, compatible_mapping)
            elif isinstance(kernel, dict):
                for device, repo_name in kernel.items():
                    add_to_mapping(layer_name, device, repo_name, mode, compatible_mapping)

        self.kernel_mapping = compatible_mapping

0 commit comments

Comments
 (0)