36 changes: 35 additions & 1 deletion ATTRIBUTIONS-Python.md
@@ -62379,7 +62379,7 @@ Copyright 2018- The Hugging Face team. All rights reserved.
- `Homepage`: https://github.com/huggingface/transformers


## triton (3.5.0)
## triton (3.5.1)

### Licenses
License: `MIT License`
@@ -62417,6 +62417,40 @@ License: `MIT License`
- `Homepage`: https://github.com/triton-lang/triton/


## triton-kernels (3.5.1)

### Licenses
License: `MIT License`

- `LICENSE` (from triton repository root):
```
Copyright 2018-2020 Philippe Tillet
Copyright 2020-2022 OpenAI

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files
(the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of the Software,
and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
```

### URLs
- `Source`: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels


## tritonclient (2.63.0)

### Licenses
27 changes: 2 additions & 25 deletions examples/models/core/gpt_oss/README.md
@@ -107,33 +107,10 @@ Once again, the function call works successfully, this time using a different fu

## Using OpenAI Triton Kernels for MoE

OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels; enable them with the steps below:

1. **Build and install Triton** (tested with the commit below):
OpenAI ships a set of Triton kernels optimized for its MoE models.

```bash
git clone https://github.com/triton-lang/triton.git
cd triton
# Specific commit verified with TensorRT-LLM
git checkout f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f
pip install -r python/requirements.txt # build-time dependencies
pip install wheel build
python3 setup.py bdist_wheel
pip install ./dist/*.whl
```

2. **Expose the Triton kernels to TensorRT-LLM**
The kernels are not packaged in the wheel, so set the environment variable `TRITON_ROOT` to your Triton clone:

```bash
export TRITON_ROOT=/local/user/triton
# TensorRT-LLM expects the kernels at:
# $TRITON_ROOT/python/triton_kernels
```

3. **Select Triton as the MoE backend**

• **trtllm-serve** (or other similar commands) — add this snippet to the YAML file passed via `--config`:
To use the Triton MoE backend with **trtllm-serve** (or other similar commands), add this snippet to the YAML file passed via `--config`:

```yaml
moe_config:
    backend: TRITON
```
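
For readers who configure through the Python LLM API rather than a YAML file, the equivalent selection might look like the sketch below. The `MoeConfig` import path and the `moe_config` keyword are assumptions (not verified against this exact release), and the checkpoint name is illustrative:

```python
# Hypothetical sketch: picking the Triton MoE backend from the LLM API.
# Assumes MoeConfig is exported from tensorrt_llm.llmapi and that LLM accepts
# a moe_config argument; adjust to the installed TensorRT-LLM version.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MoeConfig

llm = LLM(
    model="openai/gpt-oss-120b",              # illustrative checkpoint
    moe_config=MoeConfig(backend="TRITON"),   # mirrors the YAML snippet above
)
```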
2 changes: 2 additions & 0 deletions requirements.txt
@@ -67,6 +67,8 @@ etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225
blake3
soundfile
triton==3.5.1
# NOTE: the triton-kernels version should be aligned with the triton version above
triton-kernels @ git+https://github.com/triton-lang/[email protected]#subdirectory=python/triton_kernels
tiktoken
blobfile
openai-harmony==0.0.4
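
Since the kernels are now pulled in as a regular dependency, a minimal post-install sanity check could look like this sketch (it only relies on the two packages pinned above; the expected version prefix is illustrative):

```python
# Quick check that the pinned triton / triton-kernels pair installed together.
import triton

print(triton.__version__)  # expected to start with "3.5" for the pin above

try:
    import triton_kernels  # installed from the git subdirectory pin above
    print("triton_kernels import OK")
except ImportError as exc:
    print(f"triton_kernels is missing: {exc}")
```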
@@ -4,36 +4,15 @@

import torch
import torch.nn.functional as F

IS_TRITON_KERNELS_AVAILABLE = True
TRITON_KERNELS_UNAVAILABLE_REASON = ""

try:
from triton_kernels.matmul_ogs import (
FlexCtx,
FnSpecs,
FusedActivation,
PrecisionConfig,
matmul_ogs,
)
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import RoutingData, routing
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout

from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter

except Exception as _e:
IS_TRITON_KERNELS_AVAILABLE = False
TRITON_KERNELS_UNAVAILABLE_REASON = f"{type(_e).__name__}: {_e}"

FlexCtx = FnSpecs = FusedActivation = PrecisionConfig = matmul_ogs = None
InFlexData = RoutingData = routing = swiglu_fn = None
FP4 = convert_layout = wrap_torch_tensor = None
layout = StridedLayout = None
TritonEPRouter = None
from triton_kernels.matmul_ogs import FlexCtx, FnSpecs, FusedActivation, PrecisionConfig, matmul_ogs
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import RoutingData, routing
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout

from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter


# copied from transformers.integrations.mxfp4::swizzle_mxfp4 with minor modification
@@ -9,7 +9,6 @@
register_ad_pattern,
)

from ...custom_ops.fused_moe.mxfp4_moe import IS_TRITON_KERNELS_AVAILABLE
from ...utils.module import get_submodule_of_param
from ...utils.node_utils import is_op
from ..interface import BaseTransform, TransformInfo, TransformRegistry
@@ -220,11 +219,7 @@ def _apply(
shared_config,
) -> Tuple[GraphModule, TransformInfo]:
qcfg = factory.get_quant_config()
if (
not qcfg
or qcfg.get("quant_method", "") != self.algo_name
or not IS_TRITON_KERNELS_AVAILABLE
):
if not qcfg or qcfg.get("quant_method", "") != self.algo_name:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
43 changes: 19 additions & 24 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py
@@ -1,32 +1,19 @@
from __future__ import annotations

import os
import sys
from typing import Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn
import triton
import triton.language as tl

IS_TRITON_KERNELS_AVAILABLE = False
# We expect to find triton_kernels under $TRITON_ROOT/python/triton_kernels
# Triton upstream commit f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f has been verified.
triton_root = os.getenv('TRITON_ROOT')
if triton_root:
triton_root = os.path.abspath(
os.path.join(triton_root, 'python', 'triton_kernels'))
if os.path.exists(triton_root) and triton_root not in sys.path:
sys.path.insert(0, triton_root)
assert triton.__version__ >= "3.4.0", "Triton kernels are detected but the Triton wheel is too old"
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
IS_TRITON_KERNELS_AVAILABLE = True
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout

from ...model_config import ModelConfig
from ..linear import TensorParallelMode, load_weight_shard
@@ -214,11 +201,16 @@ def create_weights(self, module: torch.nn.Module):
module.intermediate_size_per_partition,
module.hidden_size,
)
# Bias shapes use the output dimension (last dim) of the transposed weight shapes
w3_w1_bias_shape = (w3_w1_weight_shape[0], w3_w1_weight_shape[2])
w2_bias_shape = (w2_weight_shape[0], w2_weight_shape[2])
super().create_weights(module,
weight_dtype,
w3_w1_weight_shape,
w2_weight_shape,
bias_dtype=torch.float32)
bias_dtype=torch.float32,
w3_w1_bias_shape=w3_w1_bias_shape,
w2_bias_shape=w2_bias_shape)
self.setup_quant_scales(module)

def setup_quant_scales(self, module: torch.nn.Module):
@@ -404,12 +396,17 @@ def create_weights(self, module: torch.nn.Module):
module.intermediate_size_per_partition,
module.hidden_size,
)
# Bias shapes use the output dimension (last dim) of the transposed weight shapes
w3_w1_bias_shape = (w3_w1_weight_shape[0], w3_w1_weight_shape[2])
w2_bias_shape = (w2_weight_shape[0], w2_weight_shape[2])
FusedMoEMethodBase.create_weights(self,
module,
weight_dtype,
w3_w1_weight_shape,
w2_weight_shape,
bias_dtype=torch.float32)
bias_dtype=torch.float32,
w3_w1_bias_shape=w3_w1_bias_shape,
w2_bias_shape=w2_bias_shape)

fc31_dequant = nn.Parameter(torch.empty(
module.expert_size_per_partition, dtype=torch.float32),
@@ -1295,8 +1292,6 @@ def __init__(
weight_loading_mode=weight_loading_mode,
layer_idx=layer_idx,
)
if not IS_TRITON_KERNELS_AVAILABLE:
raise ImportError("Triton kernels are not available.")
if torch.cuda.get_device_capability()[0] != 9 and self.ep_size > 1:
raise NotImplementedError(
"TritonFusedMoE is only supported on Hopper with EP size > 1.")
14 changes: 3 additions & 11 deletions tensorrt_llm/_torch/modules/triton_linear.py
@@ -4,20 +4,15 @@

import torch
from torch.nn.parameter import Parameter
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, matmul_ogs
from triton_kernels.numerics import InFlexData

from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.mapping import Mapping

from ...models.modeling_utils import QuantConfig
# Reuse the common Triton import setup
from .fused_moe.fused_moe_triton import (IS_TRITON_KERNELS_AVAILABLE,
maybe_update_stride,
from .fused_moe.fused_moe_triton import (maybe_update_stride,
swizzle_weight_and_scale)

if IS_TRITON_KERNELS_AVAILABLE:
from triton_kernels.matmul_ogs import (FlexCtx, PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData

from .linear import (Linear, LinearMethodBase, TensorParallelMode,
WeightsLoadingConfig, copy_weight, load_weight_shard,
load_weights_fused_gate_up_helper,
@@ -383,9 +378,6 @@ def __init__(
use_custom_cublas_mm: bool = False,
lora: Optional[LoraLayer] = None,
):
if not IS_TRITON_KERNELS_AVAILABLE:
raise ImportError("Triton kernels are not available. "
"Please install the required dependencies.")
assert not use_custom_cublas_mm, "TritonLinear does not support custom cublas mm."

super().__init__(
4 changes: 0 additions & 4 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -234,10 +234,6 @@ microsoft/phi-4:
accuracy: 90.64
mistralai/Codestral-22B-v0.1:
- accuracy: 67.10
GPT-OSS/BF16:
- accuracy: 90.3
- kv_cache_quant_algo: FP8
accuracy: 90.3
GPT-OSS/120B-MXFP4:
- accuracy: 90.3
- spec_dec_algo: Eagle