Skip to content

Commit 21404e3

Browse files
Stonesjtu, kaiyu, and Copilot
authored
[CUDA EP] Add hardswish op and add bf16 support for hardsigmoid (#25562)
### Description <!-- Describe your changes. --> Add HardSwish operator which is x*HardSigmoid(x) Add bf16 support for HardSigmoid ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> HardSwish is implemented as HardSigmoid + Mul in CUDA EP currently. A fused HardSwish should take half the time of HardSigmoid + Mul. --------- Co-authored-by: kaiyu <kaiyu@bytedance.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 75f8480 commit 21404e3

File tree

7 files changed

+95
-32
lines changed

7 files changed

+95
-32
lines changed

docs/OperatorKernels.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,10 @@ Do not modify directly.*
703703
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
704704
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
705705
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
706-
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
706+
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
707+
|||[6, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
708+
|HardSwish|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
709+
|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
707710
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
708711
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
709712
|||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/core/providers/cuda/activation/activations.cc

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -64,29 +64,33 @@ namespace cuda {
6464
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, double) \
6565
UNARY_ACTIVATION_OP_VERSIONED_TYPED(name, startver, endver, BFloat16)
6666

67-
#define UNARY_ACTIVATION_OP_HFD(name, ver) \
68-
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
69-
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
70-
UNARY_ACTIVATION_OP_TYPED(name, ver, double) \
67+
#define UNARY_ACTIVATION_OP_HFD_WITH_BF16(name, ver) \
68+
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
69+
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
70+
UNARY_ACTIVATION_OP_TYPED(name, ver, double) \
7171
UNARY_ACTIVATION_OP_TYPED(name, ver, BFloat16)
7272

73-
UNARY_ACTIVATION_OP_HFD(Elu, 6);
74-
UNARY_ACTIVATION_OP_HFD(HardSigmoid, 6);
73+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Elu, 6);
74+
UNARY_ACTIVATION_OP_VERSIONED_HFD(HardSigmoid, 6, 21);
7575
UNARY_ACTIVATION_OP_VERSIONED_HFD(LeakyRelu, 6, 15);
76-
UNARY_ACTIVATION_OP_HFD(Relu, 14);
76+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Relu, 14);
7777
UNARY_ACTIVATION_OP_VERSIONED_HFD_WITH_BF16(Relu, 13, 13);
7878
UNARY_ACTIVATION_OP_VERSIONED_HFD(Relu, 6, 12);
79-
UNARY_ACTIVATION_OP_HFD(Selu, 6);
80-
UNARY_ACTIVATION_OP_HFD(Sigmoid, 13);
79+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Selu, 6);
80+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Sigmoid, 13);
8181
UNARY_ACTIVATION_OP_VERSIONED_HFD(Sigmoid, 6, 12);
82-
UNARY_ACTIVATION_OP_HFD(Softplus, 1);
83-
UNARY_ACTIVATION_OP_HFD(Softsign, 1);
84-
UNARY_ACTIVATION_OP_HFD(Tanh, 13);
82+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softplus, 1);
83+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Softsign, 1);
84+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(Tanh, 13);
8585
UNARY_ACTIVATION_OP_VERSIONED_HFD(Tanh, 6, 12);
86-
UNARY_ACTIVATION_OP_HFD(ThresholdedRelu, 10);
86+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(ThresholdedRelu, 10);
8787

88+
UNARY_ACTIVATION_OP_VERSIONED_HFD(HardSwish, 14, 21);
8889
// Opset-16 adds BFloat16 to allowed types for the LeakyRelu operator
89-
UNARY_ACTIVATION_OP_HFD(LeakyRelu, 16);
90+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(LeakyRelu, 16);
91+
// Opset-22 adds BFloat16 to allowed types for the HardSigmoid / HardSwish operators
92+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(HardSigmoid, 22);
93+
UNARY_ACTIVATION_OP_HFD_WITH_BF16(HardSwish, 22);
9094

9195
} // namespace cuda
9296
} // namespace onnxruntime

onnxruntime/core/providers/cuda/activation/activations.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,5 +174,16 @@ class ThresholdedRelu final : public UnaryElementwise {
174174
float alpha_;
175175
};
176176

177+
template <typename T>
178+
class HardSwish final : public UnaryElementwise {
179+
public:
180+
HardSwish(const OpKernelInfo& info) : UnaryElementwise(info) {}
181+
182+
Status ComputeInternal(OpKernelContext* context) const override;
183+
184+
private:
185+
MAKE_FUNC_CTX_NULL()
186+
};
187+
177188
} // namespace cuda
178189
} // namespace onnxruntime

onnxruntime/core/providers/cuda/activation/activations_impl.cu

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ struct OP_ThresholdedRelu : public CtxThresholdedRelu {
9393
}
9494
};
9595

96+
template <typename T>
97+
struct OP_HardSwish : public CtxHardSwish {
98+
__device__ __inline__ T operator()(const T& a) const {
99+
return a * (_Min(_Max(a / (T)6 + (T)0.5, (T)0), (T)1));
100+
}
101+
};
102+
96103
#define UNARY_ACTIVATION_IMPL(name) \
97104
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
98105
UnaryElementWiseImpl(stream, \

onnxruntime/core/providers/cuda/activation/activations_impl.h

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,20 @@ typedef CtxNull CtxSoftplus;
3232
typedef CtxNull CtxSoftsign;
3333
typedef CtxNull CtxTanh;
3434
typedef CtxAlpha CtxThresholdedRelu;
35+
typedef CtxNull CtxHardSwish;
3536

36-
#define UNARY_ACTIVATION_OPS() \
37-
UNARY_ACTIVATION_OP_NAME(Elu) \
38-
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
39-
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
40-
UNARY_ACTIVATION_OP_NAME(Relu) \
41-
UNARY_ACTIVATION_OP_NAME(Selu) \
42-
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
43-
UNARY_ACTIVATION_OP_NAME(Softplus) \
44-
UNARY_ACTIVATION_OP_NAME(Softsign) \
45-
UNARY_ACTIVATION_OP_NAME(Tanh) \
46-
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu)
37+
#define UNARY_ACTIVATION_OPS() \
38+
UNARY_ACTIVATION_OP_NAME(Elu) \
39+
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
40+
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
41+
UNARY_ACTIVATION_OP_NAME(Relu) \
42+
UNARY_ACTIVATION_OP_NAME(Selu) \
43+
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
44+
UNARY_ACTIVATION_OP_NAME(Softplus) \
45+
UNARY_ACTIVATION_OP_NAME(Softsign) \
46+
UNARY_ACTIVATION_OP_NAME(Tanh) \
47+
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu) \
48+
UNARY_ACTIVATION_OP_NAME(HardSwish)
4749

4850
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
4951
template <typename T> \

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -543,9 +543,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
543543
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Elu);
544544
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Elu);
545545
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Elu);
546-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, HardSigmoid);
547-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, HardSigmoid);
548-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, HardSigmoid);
546+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, float, HardSigmoid);
547+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, double, HardSigmoid);
548+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, MLFloat16, HardSigmoid);
549549
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, float, LeakyRelu);
550550
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, double, LeakyRelu);
551551
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, MLFloat16, LeakyRelu);
@@ -1327,6 +1327,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
13271327
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Mul);
13281328
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Div);
13291329
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu);
1330+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, HardSwish);
1331+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, HardSwish);
1332+
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, HardSwish);
13301333

13311334
// OpSet 15
13321335
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow);
@@ -1485,6 +1488,16 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
14851488
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear);
14861489
#endif
14871490

1491+
// Opset 22.
1492+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid);
1493+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid);
1494+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid);
1495+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSigmoid);
1496+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSwish);
1497+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish);
1498+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish);
1499+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
1500+
14881501
// Opset 23.
14891502
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization);
14901503
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization);
@@ -1539,9 +1552,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
15391552
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, Elu)>,
15401553
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, Elu)>,
15411554
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, Elu)>,
1542-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, float, HardSigmoid)>,
1543-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, double, HardSigmoid)>,
1544-
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, MLFloat16, HardSigmoid)>,
1555+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, float, HardSigmoid)>,
1556+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, double, HardSigmoid)>,
1557+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 21, MLFloat16, HardSigmoid)>,
15451558
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, float, LeakyRelu)>,
15461559
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, double, LeakyRelu)>,
15471560
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 6, 15, MLFloat16, LeakyRelu)>,
@@ -2315,6 +2328,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
23152328
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Div)>,
23162329
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu)>,
23172330
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, Trilu)>,
2331+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, float, HardSwish)>,
2332+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, double, HardSwish)>,
2333+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 21, MLFloat16, HardSwish)>,
23182334

23192335
// OpSet 15
23202336
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow)>,
@@ -2479,6 +2495,15 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
24792495
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear)>,
24802496
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear)>,
24812497
#endif
2498+
// Opset 22
2499+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSigmoid)>,
2500+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSigmoid)>,
2501+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSigmoid)>,
2502+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSigmoid)>,
2503+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, HardSwish)>,
2504+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish)>,
2505+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish)>,
2506+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish)>,
24822507
// Opset 23
24832508
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float_float, RMSNormalization)>,
24842509
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, double_double, RMSNormalization)>,

onnxruntime/test/providers/cpu/activation/activation_op_test.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@ TEST_F(ActivationOpTest, HardSigmoid) {
121121
{{"alpha", alpha}, {"beta", beta}});
122122
}
123123

124+
#if defined(USE_CUDA)
125+
TEST_F(ActivationOpTest, HardSwish) {
126+
TestActivationOp<float>("HardSwish", input_values, [](float x) { return x * std::max(std::min(x / 6.0f + 0.5f, 1.0f), 0.0f); }, {}, {},
127+
/*is_tensorrt_supported=*/false,
128+
/*opset_version= */ 14);
129+
TestActivationOp<double>("HardSwish", input_values_double, [](double x) { return x * std::max(std::min(x / 6.0 + 0.5, 1.0), 0.0); }, {}, {},
130+
/*is_tensorrt_supported=*/false,
131+
/*opset_version= */ 14);
132+
}
133+
#endif // USE_CUDA
134+
124135
TEST_F(ActivationOpTest, Tanh) {
125136
TestActivationOp<float>("Tanh",
126137
input_values,

0 commit comments

Comments
 (0)