Skip to content

Commit 1bec83f

Browse files
authored
disable_skip_layernorm_fp16 (#45041)
1 parent 9a04540 commit 1bec83f

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ namespace tensorrt {
2222
class SkipLayerNormOpConverter : public OpConverter {
2323
public:
2424
void operator()(const framework::proto::OpDesc& op,
25-
const framework::Scope& scope, bool test_mode) override {
25+
const framework::Scope& scope,
26+
bool test_mode) override {
2627
#if IS_TRT_VERSION_GE(6000)
2728
VLOG(4) << "convert fused skip layernorm op to tensorrt layer";
2829
framework::OpDesc op_desc(op, nullptr);
@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
6364
auto creator = GetPluginRegistry()->getPluginCreator(
6465
"CustomSkipLayerNormPluginDynamic", "3");
6566
PADDLE_ENFORCE_NE(
66-
creator, nullptr,
67+
creator,
68+
nullptr,
6769
platform::errors::InvalidArgument(
6870
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
6971
const std::vector<nvinfer1::PluginField> fields{
@@ -85,22 +87,25 @@ class SkipLayerNormOpConverter : public OpConverter {
8587
inputs.data(), inputs.size(), *pluginObj);
8688

8789
PADDLE_ENFORCE_NE(
88-
plugin_layer, nullptr,
90+
plugin_layer,
91+
nullptr,
8992
platform::errors::InvalidArgument(
9093
"fail to add CustomSkipLayerNormPluginDynamic layer"));
9194
layer = plugin_layer;
9295
} else {
9396
auto creator = GetPluginRegistry()->getPluginCreator(
9497
"CustomSkipLayerNormPluginDynamic", "2");
9598
PADDLE_ENFORCE_NE(
96-
creator, nullptr,
99+
creator,
100+
nullptr,
97101
platform::errors::InvalidArgument(
98102
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
99103
int type = static_cast<int>((engine_->WithFp16() == 1)
100104
? nvinfer1::DataType::kHALF
101105
: nvinfer1::DataType::kFLOAT);
102106
int ld = input1->getDimensions().d[2]; // hidden dimension
103-
PADDLE_ENFORCE_GT(ld, 0,
107+
PADDLE_ENFORCE_GT(ld,
108+
0,
104109
platform::errors::InvalidArgument(
105110
"in CustomSkipLayerNormPluginDynamic hidden "
106111
"dimension should > 0"));
@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
128133
inputs.data(), inputs.size(), *pluginObj);
129134

130135
PADDLE_ENFORCE_NE(
131-
plugin_layer, nullptr,
136+
plugin_layer,
137+
nullptr,
132138
platform::errors::InvalidArgument(
133139
"fail to add CustomSkipLayerNormPluginDynamic layer"));
134140
layer = plugin_layer;
135141
}
136142
} else {
137143
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
138-
bool with_fp16 =
139-
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
144+
/* bool with_fp16 =
145+
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
146+
*/
147+
bool with_fp16 = false;
140148
plugin::SkipLayerNormPluginDynamic* plugin =
141-
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
142-
scale_size, eps, with_fp16);
149+
new plugin::SkipLayerNormPluginDynamic(
150+
bias, scale, bias_size, scale_size, eps, with_fp16);
143151
layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
144152
}
145153

0 commit comments

Comments
 (0)