Skip to content

Commit a163eee

Browse files
authored
Merge pull request #73 from InfiniTensor/dev-hardswish
feat: add hardswish cpu/cuda kernel
2 parents 9e4789c + afdfe93 commit a163eee

File tree

14 files changed

+103
-20
lines changed

14 files changed

+103
-20
lines changed

src/04kernel/include/kernel/collectors/simple_unary.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace refactor::kernel {
2525
Erf,
2626
Neg,
2727
Not,
28+
HardSwish,
2829
};
2930

3031
std::string_view unaryName(SimpleUnaryType type);

src/04kernel/src/collectors/simple_unary.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace refactor::kernel {
3131
CASE(Erf);
3232
CASE(Neg);
3333
CASE(Not);
34+
CASE(HardSwish);
3435
default:
3536
UNREACHABLE();
3637
}

src/04kernel/src/kernels/simple_binary/cpu_kernel.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ namespace refactor::kernel {
137137
switch (dataType.internal) {
138138
CASE_DT(std::fmod(a, b), F32);
139139
CASE_DT(a % b, U8);
140-
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I8);
140+
CASE_DT(static_cast<int8_t>(std::fmod(a, b)), I8);
141141
CASE_DT(a % b, U16);
142-
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I16);
143-
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I32);
144-
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I64);
142+
CASE_DT(static_cast<int16_t>(std::fmod(a, b)), I16);
143+
CASE_DT(static_cast<int32_t>(std::fmod(a, b)), I32);
144+
CASE_DT(static_cast<int64_t>(std::fmod(a, b)), I64);
145145
CASE_DT(std::fmod(a, b), F64);
146146
CASE_DT(a % b, U32);
147147
CASE_DT(a % b, U64);

src/04kernel/src/kernels/simple_binary/cuda_kernel.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,18 @@ extern "C" __global__ void kernel(
158158
case SimpleBinaryType::Fmod:
159159
switch (dt) {
160160
case DataType::U8:
161-
case DataType::I8:
162161
case DataType::U16:
162+
case DataType::U32:
163+
case DataType::U64:
164+
return "a % b";
165+
case DataType::I8:
166+
return "static_cast<char>(fmodf(a, b))";
163167
case DataType::I16:
168+
return "static_cast<short>(fmodf(a, b))";
164169
case DataType::I32:
170+
return "static_cast<int>(fmodf(a, b))";
165171
case DataType::I64:
166-
case DataType::U32:
167-
case DataType::U64:
168-
return "a % b < 0 ? (a % b + b) : (a % b)";
172+
return "static_cast<long long>(fmodf(a, b))";
169173
case DataType::F32:
170174
return "fmodf(a, b)";
171175
case DataType::FP16:

src/04kernel/src/kernels/simple_unary/cpu_kernel.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace refactor::kernel {
1919
Op::Tanh,
2020
Op::Neg,
2121
Op::Erf,
22+
Op::HardSwish,
2223
};
2324
return supportedOp.contains(op) && a.dataType.isCpuNumberic()
2425
? std::make_unique<K>(op, a.dataType, a.elementsSize())
@@ -49,6 +50,12 @@ namespace refactor::kernel {
4950
using M = std::conditional_t<sizeof(T) <= 4, float, double>;
5051
return static_cast<T>(std::tanh(static_cast<M>(x)));
5152
}
53+
template<class T> auto hardswishFun(T x) noexcept -> T {
54+
auto mid = x / 6.f + .5f;
55+
return (mid <= 0) ? 0
56+
: (1 <= mid) ? x
57+
: x * mid;
58+
}
5259
auto copyForUnsigned(size_t n) noexcept -> Routine {
5360
return [n](runtime::Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
5461
std::memcpy(outputs[0], inputs[0], n);
@@ -171,6 +178,13 @@ namespace refactor::kernel {
171178
default:
172179
UNREACHABLE();
173180
}
181+
case Op::HardSwish:
182+
switch (dataType) {
183+
CASE(hardswishFun, F32);
184+
CASE(hardswishFun, F64);
185+
default:
186+
UNREACHABLE();
187+
}
174188
default:
175189
UNREACHABLE();
176190
}

src/04kernel/src/kernels/simple_unary/cuda_kernel.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace refactor::kernel {
1919
static const std::unordered_set<Op>
2020
supportedOp{Op::Abs, Op::Relu, Op::Sqrt,
2121
Op::Sigmoid, Op::Tanh, Op::Neg,
22-
Op::Erf};
22+
Op::Erf, Op::HardSwish};
2323
#ifndef USE_CUDA
2424
return nullptr;
2525
#endif
@@ -154,6 +154,11 @@ extern "C" __global__ void kernel(
154154
{__(Op::Erf, DT::I64 ), "erf(static_cast<double>(x))"},
155155
{__(Op::Erf, DT::FP16), "__float2half(erff(__half2float(x)))"},
156156
{__(Op::Erf, DT::BF16), "__float2bfloat16(erff(__bfloat162float(x)))"},
157+
158+
{__(Op::HardSwish, DT::F32 ), "x * fmaxf(0.f, fminf(1.f, fmaf(1.f/6.f, x, 0.5f)))"},
159+
{__(Op::HardSwish, DT::FP16), "x * __hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, hrcp(__float2half(6.f)) * x + hrcp(__float2half(2.f))))"},
160+
{__(Op::HardSwish, DT::F64 ), "x * fmax(0.0, fmin(1.0, fma(1.0/6.0, x, 0.5)))"},
161+
157162
};
158163
// clang-format on
159164

src/04kernel/test/kernels/simple_binary/test_binary_cpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ TEST(kernel, BinaryCpu) {
8080
testBinaryCPU(SimpleBinaryType::Mul, [](float a, float b) { return a * b; });
8181
testBinaryCPU(SimpleBinaryType::Div, [](float a, float b) { return a / b; });
8282
testModCPU(SimpleBinaryType::Mod, [](int a, int b) { return a % b; });
83-
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return a % b < 0 ? (a % b + b) : (a % b); });
83+
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return static_cast<int32_t>(std::fmod(a, b)); });
8484
testBinaryCPU(SimpleBinaryType::Fmod, [](float a, float b) { return std::fmod(a, b); });
8585
}
8686

src/04kernel/test/kernels/simple_unary/test_cpu.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using namespace refactor;
55
using namespace kernel;
66

7+
using VecFloat = std::vector<float>;
8+
79
static void testOp(SimpleUnaryType opType, float check(float)) {
810
// build routine
911
auto dataTensor = Tensor::share(DataType::F32, Shape{20, 30, 50});
@@ -12,7 +14,7 @@ static void testOp(SimpleUnaryType opType, float check(float)) {
1214
auto res = runtime::Resources();
1315
auto routine = kernel->lower(res).routine;
1416
// put input data
15-
std::vector<float> data(dataTensor->elementsSize());
17+
VecFloat data(dataTensor->elementsSize());
1618
for (auto i : range0_(data.size())) { data[i] = i * 1e-4f; }
1719
auto result = data;
1820
// inference
@@ -27,9 +29,34 @@ static void testOp(SimpleUnaryType opType, float check(float)) {
2729
}
2830
}
2931

32+
static void testOpWithData(SimpleUnaryType opType, const VecFloat &data) {
33+
// build routine
34+
auto dataTensor = Tensor::share(DataType::F32, Shape{2, 3});
35+
auto kernel = SimpleUnaryCpu::build(opType, *dataTensor);
36+
ASSERT_TRUE(kernel);
37+
auto res = runtime::Resources();
38+
auto routine = kernel->lower(res).routine;
39+
// put input data
40+
VecFloat inputdata(dataTensor->elementsSize());
41+
for (auto i : range0_(inputdata.size())) { inputdata[i] = i; }
42+
auto result = inputdata;
43+
// inference
44+
{
45+
void const *inputs[]{result.data()};
46+
void *outputs[]{result.data()};
47+
routine(res, nullptr, inputs, outputs);
48+
}
49+
// check
50+
for (auto i : range0_(inputdata.size())) {
51+
EXPECT_NEAR(data[i], result[i], 1e-5);
52+
}
53+
}
54+
3055
TEST(kernel, SimpleUnaryCpu) {
3156
testOp(SimpleUnaryType::Abs, std::abs);
3257
testOp(SimpleUnaryType::Sqrt, std::sqrt);
3358
testOp(SimpleUnaryType::Tanh, std::tanh);
3459
testOp(SimpleUnaryType::Erf, std::erf);
60+
testOpWithData(SimpleUnaryType::HardSwish,
61+
VecFloat{0.000000, 0.666667, 1.666667, 3.000000, 4.000000, 5.000000});
3562
}

src/04kernel/test/kernels/simple_unary/test_cuda.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ TEST(kernel, SimpleUnaryCuda) {
5252
testOp(SimpleUnaryType::Sigmoid);
5353
testOp(SimpleUnaryType::Tanh);
5454
testOp(SimpleUnaryType::Erf);
55+
testOp(SimpleUnaryType::HardSwish);
5556
}
5657

5758
#endif

src/05computation/src/operators/simple_unary.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ namespace refactor::computation {
8181
static uint8_t ID = 19;
8282
return reinterpret_cast<size_t>(&ID);
8383
}
84+
case SimpleUnaryType::HardSwish: {
85+
static uint8_t ID = 20;
86+
return reinterpret_cast<size_t>(&ID);
87+
}
8488
default:
8589
UNREACHABLE();
8690
}
@@ -128,6 +132,8 @@ namespace refactor::computation {
128132
return "Neg";
129133
case SimpleUnaryType::Not:
130134
return "Not";
135+
case SimpleUnaryType::HardSwish:
136+
return "HardSwish";
131137
default:
132138
UNREACHABLE();
133139
}

0 commit comments

Comments
 (0)