File tree Expand file tree Collapse file tree 3 files changed +13
-11
lines changed
cpp/tensorrt_llm/kernels/cutlass_kernels/include Expand file tree Collapse file tree 3 files changed +13
-11
lines changed Original file line number Diff line number Diff line change @@ -29,8 +29,8 @@ enum class ActivationType
2929 Swiglu,
3030 Geglu,
3131 SwigluBias,
32- Identity,
3332 Relu2,
33+ Identity,
3434 InvalidType
3535};
3636
Original file line number Diff line number Diff line change @@ -40,8 +40,8 @@ class ActivationType(IntEnum):
4040 Swiglu = 3
4141 Geglu = 4
4242 SwigluBias = 5
43- Identity = 6
44- Relu2 = 7
43+ Relu2 = 6
44+ Identity = 7
4545 InvalidType = 8
4646
4747
Original file line number Diff line number Diff line change 2020import tensorrt as trt
2121import torch
2222
23+ from tensorrt_llm ._torch .utils import ActivationType
2324from tensorrt_llm ._utils import (get_init_params , str_dtype_to_torch ,
2425 str_dtype_to_trt )
2526from tensorrt_llm .layers .lora import LoraParams
4950
5051activation_str_to_int_map = {
5152 # [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
52- "gelu" : 0 ,
53- "gelu_new" : 0 ,
54- "relu" : 1 ,
55- "silu" : 2 ,
56- "swiglu" : 3 ,
57- "geglu" : 4 ,
58- "swiglu_bias" : 5 ,
59- "identity" : 6 ,
53+ "gelu" : int (ActivationType .Gelu ),
54+ "gelu_new" : int (ActivationType .Gelu ),
55+ "relu" : int (ActivationType .Relu ),
56+ "silu" : int (ActivationType .Silu ),
57+ "swiglu" : int (ActivationType .Swiglu ),
58+ "geglu" : int (ActivationType .Geglu ),
59+ "swiglu_bias" : int (ActivationType .SwigluBias ),
60+ "identity" : int (ActivationType .Identity ),
61+ "relu2" : int (ActivationType .Relu2 ),
6062}
6163
6264
You can't perform that action at this time.
0 commit comments