Skip to content

Commit cda13fa

Browse files
[cherry-pick] [metal] Add conv-hardswish fusion (#7699)
* add_hardswish_fusion * add_hardswish_fusion
1 parent 43a8cef commit cda13fa

File tree

4 files changed

+22
-4
lines changed

4 files changed

+22
-4
lines changed

lite/backends/metal/metal_kernel/texture/Common.metal

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ enum ActivationType : ushort {
5656
PRELU = 3,
5757
LEAKY_RELU = 4,
5858
HARD_SIGMOID = 5,
59+
HARD_SWISH = 10,
5960
};
6061

6162
struct DropoutParam {
@@ -67,7 +68,8 @@ struct MetalActivationParam {
6768
float threshold; // RELU6
6869
float alpha; // LEAKY_RELU
6970
float offset; // HARD_SIGMOID
70-
float slope;
71+
float slope; // HARD_SIGMOID
72+
float scale; // HARD_SWISH
7173
};
7274

7375
struct ElementwiseAddParam {
@@ -315,6 +317,9 @@ inline ftype4 activation(const ftype4 input, constant MetalActivationParam& para
315317
return fmax(input, ftype(param.alpha) * input);
316318
case HARD_SIGMOID:
317319
return fmax(0.0, fmin(1.0, ftype(param.slope) * input + ftype(param.offset)));
320+
case HARD_SWISH:
321+
return (fmin(ftype(param.threshold), fmax(0.0, input + ftype(param.offset)))) * input /
322+
param.scale;
318323
}
319324
}
320325

lite/core/optimizer/mir/fusion/conv_activation_fuse_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
7878
act_types.push_back("relu");
7979
act_types.push_back("relu6");
8080
act_types.push_back("hard_sigmoid");
81+
act_types.push_back("hard_swish");
8182
act_types.push_back("prelu");
8283
act_types.push_back("leaky_relu");
8384
}

lite/kernels/metal/image_op/conv2d_image_compute.mm

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
}
131131
}
132132

133-
// MPS don't support LeakyRelu
133+
// MPS don't support LeakyRelu and HardSwish
134134
switch (param.activation_param.active_type) {
135135
case lite_api::ActivationType::kIndentity:
136136
case lite_api::ActivationType::kRelu:
@@ -139,6 +139,9 @@
139139
break;
140140
case lite_api::ActivationType::kHardSigmoid:
141141
break;
142+
case lite_api::ActivationType::kHardSwish:
143+
should_use_mps = NO;
144+
break;
142145
case lite_api::ActivationType::kPRelu:
143146
break;
144147
case lite_api::ActivationType::kLeakyRelu:
@@ -346,11 +349,14 @@
346349
case lite_api::ActivationType::kLeakyRelu: {
347350
activate_type = (uint16_t)param.activation_param.active_type;
348351
} break;
352+
case lite_api::ActivationType::kHardSwish: {
353+
activate_type = (uint16_t)param.activation_param.active_type;
354+
} break;
349355
default: { LOG(FATAL) << "Conv2d: cannot support the activate type"; } break;
350356
}
351357
}
352358
// relu
353-
ActivationMetalParam activation_params{(unsigned short)activate_type, 0.0, 0.0, 0.0, 0.0};
359+
ActivationMetalParam activation_params{(unsigned short)activate_type, 0.0, 0.0, 0.0, 0.0, 0.0};
354360
switch (param.activation_param.active_type) {
355361
case lite_api::ActivationType::kIndentity:
356362
case lite_api::ActivationType::kRelu:
@@ -361,6 +367,11 @@
361367
case lite_api::ActivationType::kLeakyRelu: {
362368
activation_params.alpha = param.activation_param.Leaky_relu_alpha;
363369
} break;
370+
case lite_api::ActivationType::kHardSwish: {
371+
activation_params.threshold = param.activation_param.hard_swish_threshold;
372+
activation_params.offset = param.activation_param.hard_swish_offset;
373+
activation_params.scale = param.activation_param.hard_swish_scale;
374+
} break;
364375
default:
365376
break;
366377
}

lite/kernels/metal/image_op/metal_params.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ struct ActivationMetalParam {
3333
float threshold; // RELU6
3434
float alpha; // LEAKY_RELU
3535
float offset; // HARD_SIGMOID
36-
float slope;
36+
float slope; // HARD_SIGMOID
37+
float scale; // HARD_SWISH
3738
};
3839

3940
struct MetalConvParam {

0 commit comments

Comments
 (0)