File tree Expand file tree Collapse file tree 1 file changed +9
-10
lines changed Expand file tree Collapse file tree 1 file changed +9
-10
lines changed Original file line number Diff line number Diff line change 2424if is_torch_npu_available ():
2525 import torch_npu
2626
27- ACTIVATION_FUNCTIONS = {
28- "swish" : nn .SiLU () ,
29- "silu" : nn .SiLU () ,
30- "mish" : nn .Mish () ,
31- "gelu" : nn .GELU () ,
32- "relu" : nn .ReLU () ,
27+ ACT2CLS = {
28+ "swish" : nn .SiLU ,
29+ "silu" : nn .SiLU ,
30+ "mish" : nn .Mish ,
31+ "gelu" : nn .GELU ,
32+ "relu" : nn .ReLU ,
3333}
3434
3535
@@ -44,11 +44,10 @@ def get_activation(act_fn: str) -> nn.Module:
4444 """
4545
4646 act_fn = act_fn .lower ()
47- if act_fn in ACTIVATION_FUNCTIONS :
48- return ACTIVATION_FUNCTIONS [act_fn ]
47+ if act_fn in ACT2CLS :
48+ return ACT2CLS [act_fn ]()
4949 else :
50- raise ValueError (f"Unsupported activation function: { act_fn } " )
51-
50+ raise ValueError (f"activation function { act_fn } not found in ACT2FN mapping { list (ACT2CLS .keys ())} " )
5251
5352class FP32SiLU (nn .Module ):
5453 r"""
You can’t perform that action at this time.
0 commit comments