@@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
90
90
" so the index should be zero,"
91
91
" but it's (%d)" ,
92
92
output_index));
93
- PADDLE_ENFORCE_EQ (
94
- nb_inputs, 3 ,
95
- platform::errors::InvalidArgument (
96
- " The Input of the EmbEltwiseLayernorm should be 3, but we found "
97
- " it has (%d) inputs" ,
98
- nb_inputs));
99
93
nvinfer1::DimsExprs ret;
100
94
ret.nbDims = 5 ;
101
95
ret.d [0 ] = inputs[0 ].d [0 ];
@@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
113
107
PADDLE_ENFORCE_NOT_NULL (
114
108
in_out, platform::errors::InvalidArgument (
115
109
" The input of swish plugin shoule not be nullptr." ));
116
-
110
+ PADDLE_ENFORCE_EQ (nb_outputs, 1 ,
111
+ platform::errors::InvalidArgument (
112
+ " The EmbEltwiseLayerNorm's output should be one"
113
+ " but it's (%d) outputs." ,
114
+ nb_outputs));
117
115
PADDLE_ENFORCE_LT (
118
116
pos, nb_inputs + nb_outputs,
119
117
platform::errors::InvalidArgument (" The pos(%d) should be less than the "
120
118
" num(%d) of the input and the output." ,
121
119
pos, nb_inputs + nb_outputs));
122
- (in_out && pos < (nb_inputs + nb_outputs));
120
+
121
+ int all_nums = nb_inputs + nb_outputs;
123
122
124
123
const nvinfer1::PluginTensorDesc &desc = in_out[pos];
125
124
if (desc.format != nvinfer1::TensorFormat::kLINEAR ) {
@@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
131
130
}
132
131
133
132
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1 ];
134
- if (pos == 1 || pos == 2 ) {
133
+ if (pos < all_nums - 1 ) {
135
134
return desc.type == nvinfer1::DataType::kINT32 &&
136
135
desc.dims .d [0 ] == prev.dims .d [0 ] && desc.dims .d [1 ] == prev.dims .d [1 ];
137
136
}
138
137
139
- if (pos == 3 ) {
138
+ if (pos == all_nums - 1 ) {
140
139
if (sizeof (T) == sizeof (float )) {
141
140
return desc.type == nvinfer1::DataType::kFLOAT ;
142
141
} else {
143
142
return desc.type == nvinfer1::DataType::kHALF ;
144
143
}
145
144
}
145
+ return false ;
146
146
}
147
147
148
148
template <typename T>
0 commit comments