Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
2a27823
Test working as I think it should work
vthumbe1503 Aug 26, 2025
d4c06c5
initial draft of changes to get GPT oss based swiglu integrated, gate…
vthumbe1503 Sep 5, 2025
1f596af
redundant implementation for the pytorch to te hook up, refactoring t…
vthumbe1503 Sep 6, 2025
42f85c3
all gated kernels modified, pytest working for oss swiglu
vthumbe1503 Sep 8, 2025
c9d3311
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
5d06c2a
fix the merge conflict
vthumbe1503 Sep 8, 2025
025ce6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
d964b24
accidentally had removed some activations, minor bug in the templated…
vthumbe1503 Sep 8, 2025
de9ef2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
8e17473
parent de9ef2fe450daae0d4ea1b647a37219f72814f66
vthumbe1503 Sep 8, 2025
1f2c65b
accidentally removed the copyright
vthumbe1503 Sep 8, 2025
75c4b13
fix linting issue
vthumbe1503 Sep 8, 2025
288e926
minor issue in comments
vthumbe1503 Sep 8, 2025
448eceb
Commit is for another PR
vthumbe1503 Sep 10, 2025
23b5822
revert changes since this belongs to another PR
vthumbe1503 Sep 10, 2025
a1a5794
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
0d6a3ea
Revert change back since belongs to another PR
vthumbe1503 Sep 10, 2025
33c3364
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
a724c2d
Changes belong to another PR
vthumbe1503 Sep 10, 2025
34d9815
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2025
3475264
Revert changes here
vthumbe1503 Sep 10, 2025
5e687d1
address review comments
vthumbe1503 Sep 15, 2025
8535dfb
cleanup
vthumbe1503 Sep 15, 2025
fa0e9a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2025
aee3fb9
fix linting error
vthumbe1503 Sep 15, 2025
87ae3d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2025
3858eab
Address review comments, fix mxfp8 kernel bug: was not passing clampe…
vthumbe1503 Sep 18, 2025
de3080e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2025
7bf0bc4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2025
fe93c01
Use limit=0.75 in clamped SwiGLU test
timmoon10 Sep 19, 2025
5d3b169
Address review comments
vthumbe1503 Sep 19, 2025
439e5a1
Merge branch 'main' into users/vthumbe/gpt_oss_swiglu_integration
vthumbe1503 Sep 19, 2025
3a1cd45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2025
3417290
accidentally removed a line while resolving merge conflict
vthumbe1503 Sep 19, 2025
724b37a
match pytorch implementation: dclamp should be 1 for borders of clamp…
vthumbe1503 Sep 19, 2025
994765c
fix dswiglu quantization fusion bug
vthumbe1503 Sep 20, 2025
bebd272
Merge branch 'main' into users/vthumbe/gpt_oss_swiglu_integration
vthumbe1503 Sep 20, 2025
eb7f240
pass param by reference as much as possible
vthumbe1503 Sep 20, 2025
4f44e59
Merge branch 'users/vthumbe/gpt_oss_swiglu_integration' of github.com…
vthumbe1503 Sep 20, 2025
20baa44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2025
1132edf
float should rather be bool: fix by copilot
vthumbe1503 Sep 23, 2025
120339d
Merge branch 'main' into users/vthumbe/gpt_oss_swiglu_integration
vthumbe1503 Sep 29, 2025
ccccf88
Merge branch 'main' into users/vthumbe/gpt_oss_swiglu_integration
vthumbe1503 Sep 29, 2025
92d5b09
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2025
04510b7
{ missed in activation.cpp
vthumbe1503 Sep 30, 2025
442eac0
merge conflict
vthumbe1503 Sep 30, 2025
f1e3599
minor formatting change
vthumbe1503 Sep 30, 2025
b2dc251
nvfp4 change
vthumbe1503 Sep 30, 2025
0e8d139
Merge branch 'main' into users/vthumbe/gpt_oss_swiglu_integration
vthumbe1503 Sep 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,6 +1736,80 @@ def test_swiglu(
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
def test_clamped_swiglu(
    self,
    *,
    out_shape: Iterable[int] = (32, 32),
    dtype: torch.dtype,
    device: torch.device = "cuda",
    quantization: Optional[str],
    quantize_forward: bool,
    quantize_backward: bool,
    limit: float = 0.75,
    alpha: float = 1.702,
):
    """Clamped SwiGLU (GPT OSS variant): fused op vs. plain PyTorch reference."""

    # Input holds two concatenated halves [gate | linear], so its last dim is 2x.
    in_shape = list(out_shape)
    in_shape[-1] *= 2

    # Skip configurations that are inconsistent or unsupported on this device.
    quantized_compute = quantization is not None
    if not quantized_compute and (quantize_forward or quantize_backward):
        pytest.skip("Quantization scheme has not been provided")
    maybe_skip_quantization(quantization, dims=in_shape, device=device)

    # Random data (reference copy tracks gradients).
    x_ref, x_test = make_reference_and_test_tensors(
        in_shape,
        test_dtype=dtype,
        test_device=device,
    )
    dy_ref, dy_test = make_reference_and_test_tensors(
        out_shape,
        test_dtype=dtype,
        test_device=device,
        requires_grad=False,
    )

    # Plain PyTorch reference: gate half is clamped from above only,
    # linear half is clamped from both sides.
    gate, linear = x_ref.chunk(2, dim=-1)
    gate = gate.clamp(min=None, max=limit)
    linear = linear.clamp(min=-limit, max=limit)
    act = gate * torch.sigmoid(alpha * gate)
    y_ref = act * (linear + 1)
    y_ref.backward(dy_ref)

    # Fused implementation under test.
    recipe = make_recipe(quantization)
    forward = te_ops.Sequential(
        te_ops.Quantize(forward=False, backward=quantize_backward),
        te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
        te_ops.Quantize(forward=quantize_forward, backward=False),
    )
    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
        y_test = forward(x_test)
    y_test.backward(dy_test)

    # Expected numerical error: loosen tolerances under quantized compute.
    tols = dtype_tols(dtype)
    if quantized_compute and quantization == "nvfp4":
        tols = dtype_tols(tex.DType.kFloat4E2M1)
    elif quantized_compute:
        tols = dtype_tols(tex.DType.kFloat8E4M3)

    # Compare against the reference in fp64 on CPU.
    y_test = y_test.to(dtype=torch.float64, device="cpu")
    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(y_test, y_ref, **tols)
    torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@pytest.mark.parametrize("dtype", _dtypes)
Expand Down
10 changes: 4 additions & 6 deletions transformer_engine/common/activation/activation_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
}

/*! Launch a fused gated-activation forward pass.
 *
 *  \tparam ComputeType Computation type used inside the kernel (e.g. fp32).
 *  \tparam Param       Activation parameter struct (Empty for parameter-free GLUs).
 *  \tparam ActOP       Pointwise activation applied to the gate half.
 *
 *  \param[in]     input  Gated input tensor.
 *  \param[in,out] output Activation output tensor.
 *  \param[in]     p      Activation parameters forwarded to the kernel
 *                        (e.g. limit/alpha for clamped SwiGLU).
 *  \param[in]     stream CUDA stream used for the operation.
 */
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DGATED = false;     // forward pass: no incoming gradient
  constexpr NVTETensor grad = nullptr;  // unused placeholder when !IS_DGATED

  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
}

/*! Launch a fused gated-activation backward pass.
 *
 *  \tparam ComputeType Computation type used inside the kernel (e.g. fp32).
 *  \tparam Param       Activation parameter struct (Empty for parameter-free GLUs).
 *  \tparam ActOP       Pointwise activation applied to the gate half.
 *  \tparam DActOP      Derivative of the activation.
 *
 *  \param[in]     grad   Incoming gradient tensor.
 *  \param[in]     input  Forward-pass gated input tensor.
 *  \param[in,out] output Outgoing gradient tensor.
 *  \param[in]     p      Activation parameters forwarded to the kernel.
 *  \param[in]     stream CUDA stream used for the operation.
 */
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
          ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
                   cudaStream_t stream) {
  using namespace detail;
  constexpr bool IS_DGATED = true;  // backward pass: consume incoming gradient

  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
}

} // namespace transformer_engine
Expand Down
12 changes: 8 additions & 4 deletions transformer_engine/common/activation/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
/*! GeGLU forward: gated GELU activation. GELU takes no extra parameters,
 *  so an Empty param struct is passed through to the kernel launcher. */
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_geglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
}

/*! GeGLU backward: gradient of the gated GELU activation. */
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dgeglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
}

void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
Expand All @@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
/*! Quick-GeGLU forward: gated quick-GELU activation (no extra parameters). */
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_qgeglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
}

/*! Quick-GeGLU backward: gradient of the gated quick-GELU activation. */
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dqgeglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e, stream);
}
12 changes: 8 additions & 4 deletions transformer_engine/common/activation/relu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
/*! ReGLU forward: gated ReLU activation (no extra parameters). */
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_reglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, e, stream);
}

/*! ReGLU backward: gradient of the gated ReLU activation. */
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                 cudaStream_t stream) {
  NVTE_API_CALL(nvte_dreglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, e, stream);
}

void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
Expand All @@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
/*! SReGLU forward: gated squared-ReLU activation (no extra parameters). */
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_sreglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, e, stream);
}

/*! SReGLU backward: gradient of the gated squared-ReLU activation. */
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsreglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, e, stream);
}
23 changes: 21 additions & 2 deletions transformer_engine/common/activation/swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
/*! SwiGLU forward: gated SiLU/Swish activation (no extra parameters). */
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
}

/*! SwiGLU backward: gradient of the gated SiLU/Swish activation. */
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
  Empty e = {};  // no activation parameters
  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
}

/*! Clamped SwiGLU forward (GPT OSS variant): clamping bound and sigmoid
 *  scale are packed into a param struct and forwarded to the kernel. */
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_clamped_swiglu);
  using namespace transformer_engine;
  ClampedSwiGLUParam p{limit, alpha};  // {clamping limit, sigmoid scale}
  gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, p, stream);
}

/*! Clamped SwiGLU backward (GPT OSS variant): forwards the same
 *  {limit, alpha} parameters to the activation and its derivative. */
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                          float limit, float alpha, cudaStream_t stream) {
  NVTE_API_CALL(nvte_clamped_dswiglu);
  using namespace transformer_engine;
  ClampedSwiGLUParam p{limit, alpha};  // {clamping limit, sigmoid scale}
  dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
      grad, input, output, p, stream);
}
40 changes: 40 additions & 0 deletions transformer_engine/common/include/transformer_engine/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Computes the clamped gated Swish (SwiGLU) activation used in GPT OSS.
 *
 * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
 * This gated activation differs from the original SwiGLU in two ways:
 * 1. Both the gate and the pre-activation are clipped according to the limit parameter.
 * 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of the standard
 *    Swish activation, inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * \param[in]     input  Input tensor of shape [N, H * 2]; the first H columns are
 *                       the gate half, the last H columns the linear half.
 * \param[in,out] output Output tensor of shape [N, H].
 *                       It computes Act(input[N, :H]) x input[N, H:]
 *                       (after clamping; NOTE(review): the reference implementation
 *                       applies a +1 offset to the clamped linear half — confirm).
 * \param[in]     limit  Clipping limit applied to the gate and pre-activation halves.
 * \param[in]     alpha  Scaling factor for the sigmoid used inside the activation.
 * \param[in]     stream CUDA stream used for the operation.
 */
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
                         cudaStream_t stream);

/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
Expand Down Expand Up @@ -230,6 +250,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);

/*! \brief Computes the gradient of the clamped gated Swish (SwiGLU) activation used in GPT OSS.
 *
 * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
 * This activation differs from the original SwiGLU in two ways:
 * 1. Both the gate and the pre-activation are clipped according to the limit parameter.
 * 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of the standard
 *    Swish activation, inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 * \param[in]     grad   Incoming gradient of shape [N, H].
 * \param[in]     input  Forward input tensor of shape [N, H * 2].
 * \param[in,out] output Outgoing gradient of shape [N, H * 2].
 * \param[in]     limit  Clipping limit applied to the gate and pre-activation halves;
 *                       must match the value used in the forward pass.
 * \param[in]     alpha  Scaling factor for the sigmoid used inside the activation;
 *                       must match the value used in the forward pass.
 * \param[in]     stream CUDA stream used for the operation.
 */
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                          float limit, float alpha, cudaStream_t stream);

/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
Expand Down
Loading
Loading