Skip to content

Commit e50051c

Browse files
committed
Use a more robust implementation to keep ActivationType aligned across layers
Signed-off-by: Neta Zmora <96238833+nzmora-nvidia@users.noreply.github.com>
1 parent 20ad4c9 commit e50051c

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ enum class ActivationType
2929
Swiglu,
3030
Geglu,
3131
SwigluBias,
32-
Identity,
3332
Relu2,
33+
Identity,
3434
InvalidType
3535
};
3636

tensorrt_llm/_torch/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ class ActivationType(IntEnum):
4040
Swiglu = 3
4141
Geglu = 4
4242
SwigluBias = 5
43-
Identity = 6
44-
Relu2 = 7
43+
Relu2 = 6
44+
Identity = 7
4545
InvalidType = 8
4646

4747

tensorrt_llm/layers/moe.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tensorrt as trt
2121
import torch
2222

23+
from tensorrt_llm._torch.utils import ActivationType
2324
from tensorrt_llm._utils import (get_init_params, str_dtype_to_torch,
2425
str_dtype_to_trt)
2526
from tensorrt_llm.layers.lora import LoraParams
@@ -49,14 +50,15 @@
4950

5051
activation_str_to_int_map = {
5152
# [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
52-
"gelu": 0,
53-
"gelu_new": 0,
54-
"relu": 1,
55-
"silu": 2,
56-
"swiglu": 3,
57-
"geglu": 4,
58-
"swiglu_bias": 5,
59-
"identity": 6,
53+
"gelu": int(ActivationType.Gelu),
54+
"gelu_new": int(ActivationType.Gelu),
55+
"relu": int(ActivationType.Relu),
56+
"silu": int(ActivationType.Silu),
57+
"swiglu": int(ActivationType.Swiglu),
58+
"geglu": int(ActivationType.Geglu),
59+
"swiglu_bias": int(ActivationType.SwigluBias),
60+
"identity": int(ActivationType.Identity),
61+
"relu2": int(ActivationType.Relu2),
6062
}
6163

6264

0 commit comments

Comments (0)