Skip to content

Commit 4db2f36

Browse files
Stonesjtu, kaiyu, and Copilot
authored and committed
[CUDA EP] Add hardswish op and add bf16 support for hardsigmoid (#25562)
### Description <!-- Describe your changes. --> Add HardSwish operator which is x*HardSigmoid(x) Add bf16 support for HardSigmoid ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> HardSwish is implemented as HardSigmoid + Add in CUDA EP currently. A fused HardSwish should take half the time of HardSigmoid + Add. --------- Co-authored-by: kaiyu <kaiyu@bytedance.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e7d04b9 commit 4db2f36

File tree

7 files changed

+95
-32
lines changed

7 files changed

+95
-32
lines changed

docs/OperatorKernels.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,10 @@ Do not modify directly.*
706706
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
707707
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
708708
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
709-
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
709+
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
710+
|||[6, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
711+
|HardSwish|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
712+
|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
710713
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
711714
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
712715
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/core/providers/cuda/activation/activations.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,29 +64,33 @@ namespace cuda {
6464
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, double) \
6565
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, BFloat16)
6666

67-
#define UNARY_ACTIVATION_OP_HFD(name, ver) \
68-
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
69-
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
70-
UNARY_ACTIVATION_OP_TYPED(name, ver, double) \
67+
#define UNARY_ACTIVATION_OP_HFD_WITH_BF16(name, ver) \
68+
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
69+
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
70+
UNARY_ACTIVATION_OP_TYPED(name, ver, double) \
7171
UNARY_ACTIVATION_OP_TYPED(name, ver, BFloat16)
7272

73-
UNARY_ACTIVATION_OP_HFD(Elu, 6);
74-
UNARY_ACTIVATION_OP_HFD(HardSigmoid, 6);
73+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Elu, 6);
74+
UNARY_ACTIVATION_OP_VERSIONED_HFD(HardSigmoid, 6, 21);
7575
UNARY_ACTIVATION_OP_VERSIONED_HFD(LeakyRelu, 6, 15);
76-
UNARY_ACTIVATION_OP_HFD(Relu, 14);
76+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Relu, 14);
7777
UNARY_ACTIVATION_OP_VERSIONED_HFD_WITH_BF16(Relu, 13, 13);
7878
UNARY_ACTIVATION_OP_VERSIONED_HFD(Relu, 6, 12);
79-
UNARY_ACTIVATION_OP_HFD(Selu, 6);
80-
UNARY_ACTIVATION_OP_HFD(Sigmoid, 13);
79+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Selu, 6);
80+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Sigmoid, 13);
8181
UNARY_ACTIVATION_OP_VERSIONED_HFD(Sigmoid, 6, 12);
82-
UNARY_ACTIVATION_OP_HFD(Softplus, 1);
83-
UNARY_ACTIVATION_OP_HFD(Softsign, 1);
84-
UNARY_ACTIVATION_OP_HFD(Tanh, 13);
82+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softplus, 1);
83+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softsign, 1);
84+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Tanh, 13);
8585
UNARY_ACTIVATION_OP_VERSIONED_HFD(Tanh, 6, 12);
86-
UNARY_ACTIVATION_OP_HFD(ThresholdedRelu, 10);
86+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(ThresholdedRelu, 10);
8787

88+
UNARY_ACTIVATION_OP_VERSIONED_HFD(HardSwish, 14, 21);
8889
// Opset-16 adds BFloat16 to allowed types for the LeakyRelu operator
89-
UNARY_ACTIVATION_OP_HFD(LeakyRelu, 16);
90+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(LeakyRelu, 16);
91+
// Opset-22 adds BFloat16 to allowed types for the HardSigmoid / HardSwish operators
92+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(HardSigmoid, 22);
93+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(HardSwish, 22);
9094

9195
} // namespace cuda
9296
} // namespace onnxruntime

onnxruntime/core/providers/cuda/activation/activations.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,16 @@ class ThresholdedRelu final : public UnaryElementwise {
174174
float alpha_;
175175
};
176176

177+
template <typename T>
178+
class HardSwish final : public UnaryElementwise {
179+
public:
180+
HardSwish(const OpKernelInfo& info) : UnaryElementwise(info) {}
181+
182+
Status ComputeInternal(OpKernelContext* context) const override;
183+
184+
private:
185+
MAKE_FUNC_CTX_NULL()
186+
};
187+
177188
} // namespace cuda
178189
} // namespace onnxruntime

onnxruntime/core/providers/cuda/activation/activations_impl.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ struct OP_ThresholdedRelu : public CtxThresholdedRelu {
9393
}
9494
};
9595

96+
template <typename T>
97+
struct OP_HardSwish : public CtxHardSwish {
98+
__device__ __inline__ T operator()(const T& a) const {
99+
return a * (_Min(_Max(a / (T)6 + (T)0.5, (T)0), (T)1));
100+
}
101+
};
102+
96103
#define UNARY_ACTIVATION_IMPL(name) \
97104
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
98105
UnaryElementWiseImpl(stream, \

onnxruntime/core/providers/cuda/activation/activations_impl.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,20 @@ typedef CtxNull CtxSoftplus;
3232
typedef CtxNull CtxSoftsign;
3333
typedef CtxNull CtxTanh;
3434
typedef CtxAlpha CtxThresholdedRelu;
35+
typedef CtxNull CtxHardSwish;
3536

36-
#define UNARY_ACTIVATION_OPS() \
37-
UNARY_ACTIVATION_OP_NAME(Elu) \
38-
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
39-
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
40-
UNARY_ACTIVATION_OP_NAME(Relu) \
41-
UNARY_ACTIVATION_OP_NAME(Selu) \
42-
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
43-
UNARY_ACTIVATION_OP_NAME(Softplus) \
44-
UNARY_ACTIVATION_OP_NAME(Softsign) \
45-
UNARY_ACTIVATION_OP_NAME(Tanh) \
46-
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu)
37+
#define UNARY_ACTIVATION_OPS() \
38+
UNARY_ACTIVATION_OP_NAME(Elu) \
39+
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
40+
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
41+
UNARY_ACTIVATION_OP_NAME(Relu) \
42+
UNARY_ACTIVATION_OP_NAME(Selu) \
43+
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
44+
UNARY_ACTIVATION_OP_NAME(Softplus) \
45+
UNARY_ACTIVATION_OP_NAME(Softsign) \
46+
UNARY_ACTIVATION_OP_NAME(Tanh) \
47+
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu) \
48+
UNARY_ACTIVATION_OP_NAME(HardSwish)
4749

4850
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
4951
template <typename T> \

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
540540
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Elu);
541541
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Elu);
542542
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Elu);
543-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, HardSigmoid);
544-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, HardSigmoid);
545-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, HardSigmoid);
543+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, float, HardSigmoid);
544+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, double, HardSigmoid);
545+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, MLFloat16, HardSigmoid);
546546
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, float, LeakyRelu);
547547
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, double, LeakyRelu);
548548
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, MLFloat16, LeakyRelu);
@@ -1324,6 +1324,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
13241324
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Mul);
13251325
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Div);
13261326
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu);
1327+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, HardSwish);
1328+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, HardSwish);
1329+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, HardSwish);
13271330

13281331
// OpSet 15
13291332
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow);
@@ -1481,6 +1484,16 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
14811484
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear);
14821485
#endif
14831486

1487+
// Opset 22.
1488+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid);
1489+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid);
1490+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid);
1491+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSigmoid);
1492+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSwish);
1493+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish);
1494+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish);
1495+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
1496+
14841497
// Opset 23.
14851498
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
14861499
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
@@ -1535,9 +1548,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
15351548
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Elu)>,
15361549
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Elu)>,
15371550
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Elu)>,
1538-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, HardSigmoid)>,
1539-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, HardSigmoid)>,
1540-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, HardSigmoid)>,
1551+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, float, HardSigmoid)>,
1552+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, double, HardSigmoid)>,
1553+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, MLFloat16, HardSigmoid)>,
15411554
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, float, LeakyRelu)>,
15421555
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, double, LeakyRelu)>,
15431556
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, MLFloat16, LeakyRelu)>,
@@ -2311,6 +2324,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
23112324
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
23122325
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu)>,
23132326
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu)>,
2327+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, HardSwish)>,
2328+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, HardSwish)>,
2329+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, HardSwish)>,
23142330

23152331
// OpSet 15
23162332
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow)>,
@@ -2474,6 +2490,15 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
24742490
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear)>,
24752491
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear)>,
24762492
#endif
2493+
// Opset 22
2494+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid)>,
2495+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid)>,
2496+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid)>,
2497+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSigmoid)>,
2498+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSwish)>,
2499+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish)>,
2500+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish)>,
2501+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish)>,
24772502
// Opset 23
24782503
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization)>,
24792504
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization)>,

onnxruntime/test/providers/cpu/activation/activation_op_test.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@ TEST_F(ActivationOpTest, HardSigmoid) {
121121
{{"alpha", alpha}, {"beta", beta}});
122122
}
123123

124+
#if defined(USE_CUDA)
125+
TEST_F(ActivationOpTest, HardSwish) {
126+
TestActivationOp<float>("HardSwish", input_values, [](float x) { return x * std::max(std::min(x / 6.0f + 0.5f, 1.0f), 0.0f); }, {}, {},
127+
/*is_tensorrt_supported=*/false,
128+
/*opset_version= */ 14);
129+
TestActivationOp<double>("HardSwish", input_values_double, [](double x) { return x * std::max(std::min(x / 6.0 + 0.5, 1.0), 0.0); }, {}, {},
130+
/*is_tensorrt_supported=*/false,
131+
/*opset_version= */ 14);
132+
}
133+
#endif // USE_CUDA
134+
124135
TEST_F(ActivationOpTest, Tanh) {
125136
TestActivationOp<float>("Tanh",
126137
input_values,

0 commit comments

Comments (0)