Skip to content

Commit 914fd81

Browse files
authored
cherry-pick :fix emb eltwise layernorm (#24873) (#25363)
test=release/1.8
1 parent 06c86d4 commit 914fd81

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
9090
"so the index should be zero,"
9191
"but it's (%d)",
9292
output_index));
93-
PADDLE_ENFORCE_EQ(
94-
nb_inputs, 3,
95-
platform::errors::InvalidArgument(
96-
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
97-
"it has (%d) inputs",
98-
nb_inputs));
9993
nvinfer1::DimsExprs ret;
10094
ret.nbDims = 5;
10195
ret.d[0] = inputs[0].d[0];
@@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
113107
PADDLE_ENFORCE_NOT_NULL(
114108
in_out, platform::errors::InvalidArgument(
115109
"The input of swish plugin shoule not be nullptr."));
116-
110+
PADDLE_ENFORCE_EQ(nb_outputs, 1,
111+
platform::errors::InvalidArgument(
112+
"The EmbEltwiseLayerNorm's output should be one"
113+
"but it's (%d) outputs.",
114+
nb_outputs));
117115
PADDLE_ENFORCE_LT(
118116
pos, nb_inputs + nb_outputs,
119117
platform::errors::InvalidArgument("The pos(%d) should be less than the "
120118
"num(%d) of the input and the output.",
121119
pos, nb_inputs + nb_outputs));
122-
(in_out && pos < (nb_inputs + nb_outputs));
120+
121+
int all_nums = nb_inputs + nb_outputs;
123122

124123
const nvinfer1::PluginTensorDesc &desc = in_out[pos];
125124
if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
@@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
131130
}
132131

133132
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
134-
if (pos == 1 || pos == 2) {
133+
if (pos < all_nums - 1) {
135134
return desc.type == nvinfer1::DataType::kINT32 &&
136135
desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
137136
}
138137

139-
if (pos == 3) {
138+
if (pos == all_nums - 1) {
140139
if (sizeof(T) == sizeof(float)) {
141140
return desc.type == nvinfer1::DataType::kFLOAT;
142141
} else {
143142
return desc.type == nvinfer1::DataType::kHALF;
144143
}
145144
}
145+
return false;
146146
}
147147

148148
template <typename T>

0 commit comments

Comments
 (0)