
Commit 6902ffa

remove the triton_kernels dependency in favor of the kernels package (#39926)
* remove dep
* style
* rm import
* fix
* style
* simplify
* style
1 parent: cb2e0df · commit: 6902ffa
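In short: instead of importing from a locally installed triton_kernels package, the MXFP4 integration now loads the same kernels from the Hugging Face Hub at runtime through the kernels package. A minimal sketch of the new access pattern, using only names that appear in this diff (the module returned by get_kernel is assumed to expose the same attribute layout as triton_kernels):

# Sketch of the pattern introduced by this commit (not a verbatim excerpt).
# Assumes the `kernels` package is installed and a CUDA GPU with triton >= 3.4.0.
from kernels import get_kernel

# Loads the community build of triton_kernels from the Hub instead of
# requiring a local `triton_kernels` installation.
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

# Attributes mirror the old import paths, e.g.:
# old: from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp
# old: from triton_kernels.matmul_ogs import matmul_ogs
matmul_ogs = triton_kernels_hub.matmul_ogs.matmul_ogs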

File tree: 6 files changed (+58, -41 lines)


src/transformers/integrations/mxfp4.py

Lines changed: 31 additions & 10 deletions
@@ -49,17 +49,21 @@

 # Copied from GPT_OSS repo and vllm
 def quantize_to_mxfp4(w):
-    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp
+    downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp

     w, w_scale = downcast_to_mxfp(w.to(torch.bfloat16), torch.uint8, axis=1)
     w, w_scale = swizzle_mxfp4(w, w_scale)
     return w, w_scale


 def swizzle_mxfp4(w, w_scale):
-    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
-    from triton_kernels.tensor_details import layout
-    from triton_kernels.tensor_details.layout import StridedLayout
+    FP4, convert_layout, wrap_torch_tensor = (
+        triton_kernels_hub.tensor.FP4,
+        triton_kernels_hub.tensor.convert_layout,
+        triton_kernels_hub.tensor.wrap_torch_tensor,
+    )
+    layout = triton_kernels_hub.tensor_details.layout
+    StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout

     value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
     w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
@@ -173,8 +177,12 @@ def __init__(self, config):
         self.down_proj_precision_config = None

     def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
-        from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
-        from triton_kernels.swiglu import swiglu_fn
+        FnSpecs, FusedActivation, matmul_ogs = (
+            triton_kernels_hub.matmul_ogs.FnSpecs,
+            triton_kernels_hub.matmul_ogs.FusedActivation,
+            triton_kernels_hub.matmul_ogs.matmul_ogs,
+        )
+        swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn

         with torch.cuda.device(hidden_states.device):
             act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, None), 2)
@@ -211,7 +219,12 @@ def routing_torch_dist(
 ):
     import os

-    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch
+    GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
+        triton_kernels_hub.routing.GatherIndx,
+        triton_kernels_hub.routing.RoutingData,
+        triton_kernels_hub.routing.ScatterIndx,
+        triton_kernels_hub.routing.compute_expt_data_torch,
+    )

     with torch.cuda.device(logits.device):
         world_size = torch.distributed.get_world_size()
@@ -274,7 +287,7 @@ def mlp_forward(self, hidden_states):
     if dist.is_available() and dist.is_initialized():
         routing = routing_torch_dist
     else:
-        from triton_kernels.routing import routing
+        routing = triton_kernels_hub.routing.routing

     routing = routing
     batch_size = hidden_states.shape[0]
@@ -337,8 +350,11 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):


 def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
-    from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
-
+    PrecisionConfig, FlexCtx, InFlexData = (
+        triton_kernels_hub.matmul_ogs.PrecisionConfig,
+        triton_kernels_hub.matmul_ogs.FlexCtx,
+        triton_kernels_hub.matmul_ogs.InFlexData,
+    )
     from ..integrations.tensor_parallel import shard_and_distribute_module

     model = kwargs.get("model", None)
@@ -450,6 +466,11 @@ def replace_with_mxfp4_linear(
 ):
     if quantization_config.dequantize:
         return model
+    else:
+        from kernels import get_kernel
+
+        global triton_kernels_hub
+        triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

     modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
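Taken together, the hub module is loaded once in replace_with_mxfp4_linear (and again in the quantizer below), and the helpers above then read everything they need from its attributes. A standalone sketch of the quantization path, assuming a CUDA GPU with triton >= 3.4.0 and the kernels package installed (the tensor shape and print are illustrative; the call signature matches quantize_to_mxfp4 above):

import torch
from kernels import get_kernel

# Load the community kernels once, as replace_with_mxfp4_linear now does.
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
downcast_to_mxfp = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp

# Toy bf16 weight; axis=1 mirrors the call in quantize_to_mxfp4 above.
w = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
w_fp4, w_scale = downcast_to_mxfp(w, torch.uint8, axis=1)
print(w_fp4.shape, w_scale.shape)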

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 14 additions & 6 deletions
@@ -21,9 +21,9 @@

 from ..utils import (
     is_accelerate_available,
+    is_kernels_available,
     is_torch_available,
     is_triton_available,
-    is_triton_kernels_availalble,
     logging,
 )
 from .quantizers_utils import get_module_from_name
@@ -68,7 +68,7 @@ def validate_environment(self, *args, **kwargs):

         compute_capability = torch.cuda.get_device_capability()
         gpu_is_supported = compute_capability >= (7, 5)
-        kernels_available = is_triton_available("3.4.0") and is_triton_kernels_availalble()
+        kernels_available = is_triton_available("3.4.0") and is_kernels_available()

         if self.pre_quantized:
             # On unsupported GPUs or without kernels, we will dequantize the model to bf16
@@ -82,7 +82,7 @@ def validate_environment(self, *args, **kwargs):

         if not kernels_available:
             logger.warning_once(
-                "MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed, we will default to dequantizing the model to bf16"
+                "MXFP4 quantization requires triton >= 3.4.0 and kernels installed, we will default to dequantizing the model to bf16"
             )
             self.quantization_config.dequantize = True
             return
@@ -95,6 +95,12 @@ def validate_environment(self, *args, **kwargs):
             # we can't quantize the model in this case so we raise an error
             raise ValueError("MXFP4 quantization requires triton >= 3.4.0 and triton_kernels installed")

+        if not self.pre_quantized:
+            from kernels import get_kernel
+
+            global triton_kernels_hub
+            triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
+
         device_map = kwargs.get("device_map", None)
         if device_map is None:
             logger.warning_once(
@@ -160,13 +166,15 @@ def create_quantized_param(
         unexpected_keys: Optional[list[str]] = None,
         **kwargs,
     ):
-        if is_triton_kernels_availalble() and is_triton_available("3.4.0"):
-            from triton_kernels.matmul_ogs import FlexCtx, InFlexData, PrecisionConfig
-
         from ..integrations import Mxfp4GptOssExperts, dequantize, load_and_swizzle_mxfp4, quantize_to_mxfp4
         from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts

         if not self.pre_quantized:
+            PrecisionConfig, FlexCtx, InFlexData = (
+                triton_kernels_hub.matmul_ogs.PrecisionConfig,
+                triton_kernels_hub.matmul_ogs.FlexCtx,
+                triton_kernels_hub.matmul_ogs.InFlexData,
+            )
             module, _ = get_module_from_name(model, param_name)
             with torch.cuda.device(target_device):
                 if isinstance(module, Mxfp4GptOssExperts):
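These environment checks decide between real MXFP4 loading and the bf16 dequantize fallback. The fallback can also be requested explicitly through the config; a short usage sketch based on the tests further down (model name taken from the test file, other settings illustrative):

import torch
from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config

# dequantize=True expands the MXFP4 checkpoint to bf16; the quantizer falls back to
# this automatically when triton >= 3.4.0 or the kernels package is not available.
quantization_config = Mxfp4Config(dequantize=True)
model = GptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")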

src/transformers/testing_utils.py

Lines changed: 0 additions & 8 deletions
@@ -170,7 +170,6 @@
     is_torchdynamo_available,
     is_torchvision_available,
     is_triton_available,
-    is_triton_kernels_availalble,
     is_vision_available,
     is_vptq_available,
     strtobool,
@@ -471,13 +470,6 @@ def decorator(test_case):
     return decorator


-def require_triton_kernels(test_case):
-    """
-    Decorator marking a test that requires triton_kernels. These tests are skipped when triton_kernels isn't installed.
-    """
-    return unittest.skipUnless(is_triton_kernels_availalble(), "test requires triton_kernels")(test_case)
-
-
 def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
     """
     Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed.
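The test file below switches to a require_kernels decorator, whose definition is not part of this diff; presumably it mirrors the removed helper, roughly as in this sketch (is_kernels_available is the check imported in the quantizer diff above):

# Hypothetical reconstruction of the decorator the tests rely on; not shown in this commit.
import unittest

from transformers.utils import is_kernels_available


def require_kernels(test_case):
    """
    Decorator marking a test that requires the kernels package.
    """
    return unittest.skipUnless(is_kernels_available(), "test requires kernels")(test_case)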

src/transformers/utils/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -270,7 +270,6 @@
     is_torchvision_v2_available,
     is_training_run_on_sagemaker,
     is_triton_available,
-    is_triton_kernels_availalble,
     is_uroman_available,
     is_vision_available,
     is_vptq_available,

src/transformers/utils/import_utils.py

Lines changed: 0 additions & 5 deletions
@@ -238,7 +238,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _matplotlib_available = _is_package_available("matplotlib")
 _mistral_common_available = _is_package_available("mistral_common")
 _triton_available, _triton_version = _is_package_available("triton", return_version=True)
-_triton_kernels_available = _is_package_available("triton_kernels")

 _torch_version = "N/A"
 _torch_available = False
@@ -423,10 +422,6 @@ def is_triton_available(min_version: str = TRITON_MIN_VERSION):
     return _triton_available and version.parse(_triton_version) >= version.parse(min_version)


-def is_triton_kernels_availalble():
-    return _triton_kernels_available
-
-
 def is_hadamard_available():
     return _hadamard_available
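The replacement check, is_kernels_available, is imported elsewhere in this commit but its definition is not shown here; it presumably follows the same _is_package_available pattern as the helper being removed, roughly:

# Hypothetical sketch; the actual definition lives outside this diff.
_kernels_available = _is_package_available("kernels")


def is_kernels_available():
    return _kernels_available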

tests/quantization/mxfp4/test_mxfp4.py

Lines changed: 13 additions & 11 deletions
@@ -18,11 +18,11 @@

 from transformers import AutoTokenizer, GptOssForCausalLM, Mxfp4Config
 from transformers.testing_utils import (
+    require_kernels,
     require_torch,
     require_torch_gpu,
     require_torch_large_gpu,
     require_triton,
-    require_triton_kernels,
     slow,
 )
 from transformers.utils import (
@@ -194,7 +194,7 @@ def test_quantizer_validation_missing_triton(self):
         """Test quantizer validation when triton is not available"""
         with (
             patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
-            patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
+            patch("transformers.quantizers.quantizer_mxfp4.is_kernels_available", return_value=False),
         ):
             from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

@@ -208,7 +208,7 @@ def test_quantizer_validation_missing_triton_pre_quantized_no_dequantize(self):
         """Test quantizer validation when triton is not available but model is pre-quantized and dequantize is False"""
         with (
             patch("transformers.quantizers.quantizer_mxfp4.is_triton_available", return_value=False),
-            patch("transformers.quantizers.quantizer_mxfp4.is_triton_kernels_availalble", return_value=False),
+            patch("transformers.quantizers.quantizer_mxfp4.is_kernels_available", return_value=False),
         ):
             from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer

@@ -348,7 +348,7 @@ def test_convert_moe_packed_tensors(self):
         self.assertEqual(result.dtype, torch.bfloat16)

     @require_triton(min_version="3.4.0")
-    @require_triton_kernels
+    @require_kernels
     @require_torch_gpu
     @require_torch
     def test_quantize_to_mxfp4(self):
@@ -368,12 +368,14 @@ def test_quantize_to_mxfp4(self):

 @require_torch
 @require_torch_large_gpu
+@require_triton(min_version="3.4.0")
+@require_kernels
 @slow
 class Mxfp4ModelTest(unittest.TestCase):
     """Test mxfp4 with actual models (requires specific model and hardware)"""

     # These should be paths to real OpenAI MoE models for proper testing
-    model_name_packed = "/fsx/mohamed/oai-hf/tests/20b_converted_packed"  # TODO: Use real packed quantized model
+    model_name = "openai/gpt-oss-20b"

     input_text = "Once upon a time"

@@ -421,12 +423,12 @@ def test_gpt_oss_model_loading_quantized_with_device_map(self):
         self.assertFalse(quantization_config.dequantize)

         model = GptOssForCausalLM.from_pretrained(
-            self.model_name_packed,
+            self.model_name,
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.check_inference_correctness_quantized(model, tokenizer)

     def test_gpt_oss_model_loading_dequantized_with_device_map(self):
@@ -438,12 +440,12 @@ def test_gpt_oss_model_loading_dequantized_with_device_map(self):
         self.assertTrue(quantization_config.dequantize)

         model = GptOssForCausalLM.from_pretrained(
-            self.model_name_packed,
+            self.model_name,
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
-        tokenizer = AutoTokenizer.from_pretrained(self.model_name_packed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.check_inference_correctness_quantized(model, tokenizer)

     def test_model_device_map_validation(self):
@@ -464,12 +466,12 @@ def test_memory_footprint_comparison(self):
         # Expected: quantized < dequantized < unquantized memory usage
         quantization_config = Mxfp4Config(dequantize=True)
         quantized_model = GptOssForCausalLM.from_pretrained(
-            self.model_name_packed,
+            self.model_name,
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
         dequantized_model = GptOssForCausalLM.from_pretrained(
-            self.model_name_packed,
+            self.model_name,
             torch_dtype=torch.bfloat16,
             device_map="auto",
             quantization_config=quantization_config,
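For completeness, a minimal generation sketch along the lines of what check_inference_correctness_quantized presumably asserts (prompt taken from the test file; max_new_tokens and decoding options are illustrative):

import torch
from transformers import AutoTokenizer, GptOssForCausalLM

# Loading the MXFP4 checkpoint as-is exercises the quantized path; per the decorators
# above this needs a large GPU, triton >= 3.4.0 and the kernels package.
model = GptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
inputs = tokenizer("Once upon a time", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))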
