@@ -22,7 +22,8 @@ namespace tensorrt {
22
22
class SkipLayerNormOpConverter : public OpConverter {
23
23
public:
24
24
void operator ()(const framework::proto::OpDesc& op,
25
- const framework::Scope& scope, bool test_mode) override {
25
+ const framework::Scope& scope,
26
+ bool test_mode) override {
26
27
#if IS_TRT_VERSION_GE(6000)
27
28
VLOG (4 ) << " convert fused skip layernorm op to tensorrt layer" ;
28
29
framework::OpDesc op_desc (op, nullptr );
@@ -63,7 +64,8 @@ class SkipLayerNormOpConverter : public OpConverter {
63
64
auto creator = GetPluginRegistry ()->getPluginCreator (
64
65
" CustomSkipLayerNormPluginDynamic" , " 3" );
65
66
PADDLE_ENFORCE_NE (
66
- creator, nullptr ,
67
+ creator,
68
+ nullptr ,
67
69
platform::errors::InvalidArgument (
68
70
" fail to get creator of CustomSkipLayerNormPluginDynamic" ));
69
71
const std::vector<nvinfer1::PluginField> fields{
@@ -85,22 +87,25 @@ class SkipLayerNormOpConverter : public OpConverter {
85
87
inputs.data (), inputs.size (), *pluginObj);
86
88
87
89
PADDLE_ENFORCE_NE (
88
- plugin_layer, nullptr ,
90
+ plugin_layer,
91
+ nullptr ,
89
92
platform::errors::InvalidArgument (
90
93
" fail to add CustomSkipLayerNormPluginDynamic layer" ));
91
94
layer = plugin_layer;
92
95
} else {
93
96
auto creator = GetPluginRegistry ()->getPluginCreator (
94
97
" CustomSkipLayerNormPluginDynamic" , " 2" );
95
98
PADDLE_ENFORCE_NE (
96
- creator, nullptr ,
99
+ creator,
100
+ nullptr ,
97
101
platform::errors::InvalidArgument (
98
102
" fail to get creator of CustomSkipLayerNormPluginDynamic" ));
99
103
int type = static_cast <int >((engine_->WithFp16 () == 1 )
100
104
? nvinfer1::DataType::kHALF
101
105
: nvinfer1::DataType::kFLOAT );
102
106
int ld = input1->getDimensions ().d [2 ]; // hidden dimension
103
- PADDLE_ENFORCE_GT (ld, 0 ,
107
+ PADDLE_ENFORCE_GT (ld,
108
+ 0 ,
104
109
platform::errors::InvalidArgument (
105
110
" in CustomSkipLayerNormPluginDynamic hidden "
106
111
" dimension should > 0" ));
@@ -128,18 +133,21 @@ class SkipLayerNormOpConverter : public OpConverter {
128
133
inputs.data (), inputs.size (), *pluginObj);
129
134
130
135
PADDLE_ENFORCE_NE (
131
- plugin_layer, nullptr ,
136
+ plugin_layer,
137
+ nullptr ,
132
138
platform::errors::InvalidArgument (
133
139
" fail to add CustomSkipLayerNormPluginDynamic layer" ));
134
140
layer = plugin_layer;
135
141
}
136
142
} else {
137
143
float eps = BOOST_GET_CONST (float , op_desc.GetAttr (" epsilon" ));
138
- bool with_fp16 =
139
- engine_->WithFp16 () && !engine_->disable_trt_plugin_fp16 ();
144
+ /* bool with_fp16 =
145
+ engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
146
+ */
147
+ bool with_fp16 = false ;
140
148
plugin::SkipLayerNormPluginDynamic* plugin =
141
- new plugin::SkipLayerNormPluginDynamic (bias, scale, bias_size,
142
- scale_size, eps, with_fp16);
149
+ new plugin::SkipLayerNormPluginDynamic (
150
+ bias, scale, bias_size, scale_size, eps, with_fp16);
143
151
layer = engine_->AddDynamicPlugin (inputs.data (), 2 , plugin);
144
152
}
145
153
0 commit comments