
Commit 1d25c66

[Cherry-pick][Paddle Inference] fix mixed precision diff (#49477)
* disable scale op in amp pass
* Do not insert redundant cast op
* fix fused_fc_elementwise_layernorm kernel diff
* fix fc kernel diff
1 parent 7696ae0 commit 1d25c66

File tree

4 files changed: +115 -79 lines changed

paddle/fluid/framework/ir/auto_mixed_precision_pass.cc

Lines changed: 107 additions & 61 deletions
@@ -78,11 +78,11 @@ inline bool VarNodeHasDtype(Node* var_node) {
          (type == VarType::VOCAB);
 }
 
-inline bool IsFloatType(VarType::Type type) {
+inline bool IsFP32AndFP64(VarType::Type type) {
   return (type == VarType::FP64) || (type == VarType::FP32);
 }
 
-inline bool IsHalfType(VarType::Type type) {
+inline bool IsFP16AndBFP16(VarType::Type type) {
   return (type == VarType::FP16) || (type == VarType::BF16);
 }
 
@@ -159,26 +159,16 @@ bool OpSupportPrecision(const std::string& op_type,
 // The set of ops that support fp16 calculation and are considered
 // numerically-dangerous, slower and whose effects may also be observed in
 // downstream ops.
+// ref to python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
 void AutoMixedPrecisionPass::SetDefaultBlacklist() const {
   black_list_.insert({
       // numerically-dangerous
-      "acos",
-      "asin",
-      "cosh",
-      "tan",
       "exp",
-      "expm1",
       "square",
       "log",
-      "log2",
-      "log10",
-      "log1p",
-      "logsumexp",
       "mean",
-      "rsqrt",
       "sum",
       "cos_sim",
-      "softmax",
       "softmax_with_cross_entropy",
       "sigmoid_cross_entropy_with_logits",
       "c_softmax_with_cross_entropy",
@@ -272,6 +262,9 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
   VLOG(4) << "InsertCastOp done";
   RestoreOpOriginType();
   VLOG(4) << "RestoreOpOriginType done";
+  LOG(INFO) << "The number of ops run at low precision ["
+            << op_run_low_precision_.size() << "/" << op_original_type_.size()
+            << "]";
 }
 
 void AutoMixedPrecisionPass::SetOpUniqueType() const {
@@ -315,22 +308,36 @@ void AutoMixedPrecisionPass::ProcessOpWithDtypeAttr() const {
   for (const auto& nodes : all_op_nodes_) {
     for (auto* op_node : nodes) {
       auto op_type = op_node->Op()->Type();
+
+      if (op_node->Op()->HasAttr("in_dtype")) {
+        auto* var_node = op_node->inputs[0];
+        auto* real_var_node = real_vars_[var_node->Var()->Name()];
+        if (IsFP16AndBFP16(real_var_node->Var()->GetDataType())) {
+          op_node->Op()->SetAttr(
+              "in_dtype",
+              static_cast<int>(framework::TransToProtoVarType(low_precision_)));
+          op_node->Op()->Flush();
+          VLOG(4) << "process op with in_dtype attr: " << op_type << " ( "
+                  << static_cast<int>(real_var_node->Var()->GetDataType())
+                  << " --->" << static_cast<int>(low_precision_) << " )";
+        }
+      }
+
       if (op_run_low_precision_.count(op_type) == 0) continue;
 
       if (op_node->Op()->HasAttr("dtype")) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-        if (IsFloatType(static_cast<VarType::Type>(dtype))) {
+        if (IsFP32AndFP64(static_cast<VarType::Type>(dtype))) {
           op_node->Op()->SetAttr(
               "dtype",
               static_cast<int>(framework::TransToProtoVarType(low_precision_)));
           op_node->Op()->Flush();
           VLOG(4) << "process op with dtype attr: " << op_type << " ( " << dtype
                   << " --->" << static_cast<int>(low_precision_) << " )";
         }
-      }
-      if (op_node->Op()->HasAttr("out_dtype")) {
+      } else if (op_node->Op()->HasAttr("out_dtype")) {
         auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
-        if (IsFloatType(static_cast<VarType::Type>(out_dtype))) {
+        if (IsFP32AndFP64(static_cast<VarType::Type>(out_dtype))) {
           op_node->Op()->SetAttr(
               "out_dtype",
               static_cast<int>(framework::TransToProtoVarType(low_precision_)));
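Note on the new in_dtype branch: it runs before the op_run_low_precision_ check because a cast op's "in_dtype" attr lives in the op desc, and once the pass has rewritten the producer var to fp16/bf16, a stale fp32 attr would make the runtime reinterpret the converted bits. A minimal standalone sketch of that invariant (toy types, not Paddle's graph classes):

#include <cassert>
#include <iostream>

enum class DType { FP32, FP16 };

struct Var {
  DType dtype;
};

struct CastOp {
  DType in_dtype;   // attr stored in the op desc
  DType out_dtype;  // attr stored in the op desc
  Var* input;
};

// Mirrors the pass's intent: once the producer var has been converted, the
// desc attr must be resynced or the kernel dispatches on a stale dtype.
void SyncInDtype(CastOp& op) {
  if (op.input->dtype == DType::FP16) op.in_dtype = DType::FP16;
}

int main() {
  Var x{DType::FP16};                         // var already rewritten by the pass
  CastOp cast{DType::FP32, DType::FP32, &x};  // attrs still from the fp32 graph
  SyncInDtype(cast);
  assert(cast.in_dtype == x.dtype);           // attr and var agree again
  std::cout << "in_dtype resynced\n";
  return 0;
}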
@@ -359,37 +366,55 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
 
       if (op_node->Op()->HasAttr("dtype")) {
         auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
-        support_low_precision = support_low_precision &&
-                                IsFloatType(static_cast<VarType::Type>(dtype));
+        support_low_precision =
+            support_low_precision &&
+            IsFP32AndFP64(static_cast<VarType::Type>(dtype));
       } else if (op_node->Op()->HasAttr("out_dtype")) {
         auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
         support_low_precision =
             support_low_precision &&
-            IsFloatType(static_cast<VarType::Type>(out_dtype));
-      } else {
-        // if op's input var and output var is not dense tensor, the op should
-        // not run at low precision.
-        for (auto* in_var_node : op_node->inputs) {
-          CHECK_EQ(in_var_node->IsVar(), true);
-          auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
-          if (real_in_var_node->Var()->Persistable()) continue;
+            IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
+      }
 
+      // If scale op's "scale" and "bias" attr value exceed the range of fp16
+      // and bf16, it cannot run at low precision.
+      if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
+        auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
+        auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
+        if (low_precision_ == phi::DataType::FLOAT16) {
           support_low_precision =
               support_low_precision &&
-              (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
-        }
-
-        for (auto* out_var_node : op_node->outputs) {
-          CHECK_EQ(out_var_node->IsVar(), true);
-          auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
-          if (real_out_var_node->Var()->Persistable()) continue;
-
+              phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
+              phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
+        } else if (low_precision_ == phi::DataType::BFLOAT16) {
           support_low_precision =
               support_low_precision &&
-              (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+              phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
+              phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
         }
       }
 
+      // if op's input var and output var is not dense tensor, the op should
+      // not run at low precision.
+      for (auto* in_var_node : op_node->inputs) {
+        CHECK_EQ(in_var_node->IsVar(), true);
+        auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
+        if (real_in_var_node->Var()->Persistable()) continue;
+
+        support_low_precision =
+            support_low_precision &&
+            (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+      }
+      for (auto* out_var_node : op_node->outputs) {
+        CHECK_EQ(out_var_node->IsVar(), true);
+        auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
+        if (real_out_var_node->Var()->Persistable()) continue;
+
+        support_low_precision =
+            support_low_precision &&
+            (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
+      }
+
       if (support_low_precision) {
         op_run_low_precision_.insert(op_type);
         VLOG(4) << "support precision: " << op_type << " run at low precision";
@@ -439,7 +464,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
       }
 
       // when op_1 only support cpu kernel. if op_2's intput var is op_1's
-      // output var, then op_2 should not run half.
+      // output var, then op_2 should not run at low precision.
       if (GetOpOriginalType(op_type) != "feed" &&
           !GpuKernelSupportPrecision(GetOpOriginalType(op_type),
                                      phi::DataType::FLOAT32)) {
@@ -597,7 +622,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
         auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
         auto in_var_name = real_in_var_node->Var()->Name();
 
-        if (!IsFloatType(real_in_var_node->Var()->GetDataType())) continue;
+        if (!IsFP32AndFP64(real_in_var_node->Var()->GetDataType())) continue;
         if (!VarNodeHasDtype(real_in_var_node)) continue;
         if (InputVarsNotConvert(op_node, in_var_name)) continue;
 
@@ -616,7 +641,7 @@ void AutoMixedPrecisionPass::SetVarPrecision() const {
         auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
         auto out_var_name = real_out_var_node->Var()->Name();
 
-        if (!IsFloatType(real_out_var_node->Var()->GetDataType())) continue;
+        if (!IsFP32AndFP64(real_out_var_node->Var()->GetDataType())) continue;
         if (!VarNodeHasDtype(real_out_var_node)) continue;
         if (OutputVarsNotConvert(op_node, out_var_name)) continue;
 
@@ -656,7 +681,7 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const {
   auto var_names = scope->LocalVarNames();
   for (const auto& var_name : var_names) {
     if (vars_convert_to_low_precision_.count(var_name)) {
-      VLOG(4) << var_name << "'s data type was convert to half";
+      VLOG(4) << var_name << "'s data type was convert to low precision";
 
       auto* var = scope->FindLocalVar(var_name);
       CHECK_EQ(var->IsType<phi::DenseTensor>(), true);
@@ -683,16 +708,18 @@ void AutoMixedPrecisionPass::ConvertWeightsData() const {
           }
         }
       } else if (low_precision_ == phi::DataType::BFLOAT16) {
-        auto* half_data =
+        auto* low_precision_data =
             low_precision_tensor.mutable_data<phi::dtype::bfloat16>(
                 phi::CPUPlace{});
         for (int64_t i = 0; i < origin_tensor->numel(); i++) {
           if (origin_tensor->dtype() == phi::DataType::FLOAT64) {
             auto* origin_data = origin_tensor->data<double>();
-            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+            low_precision_data[i] =
+                static_cast<phi::dtype::bfloat16>(origin_data[i]);
           } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) {
             auto* origin_data = origin_tensor->data<float>();
-            half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]);
+            low_precision_data[i] =
+                static_cast<phi::dtype::bfloat16>(origin_data[i]);
           }
         }
       }
@@ -732,25 +759,44 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
         VLOG(4) << "process var: " << real_in_var_node->Var()->Name()
                 << " with type " << in_var_type;
 
-        if (IsFloatType(in_var_type) && op_run_low_precision_.count(op_type)) {
-          DoInsertCastOp(subgraphes_[i],
-                         in_var_node,
-                         op_node,
-                         in_var_type,
-                         framework::TransToProtoVarType(low_precision_),
-                         block_desc,
-                         &suffix,
-                         &cache);
-        } else if (IsHalfType(in_var_type) &&
+        if (IsFP32AndFP64(in_var_type) &&
+            op_run_low_precision_.count(op_type)) {
+          auto to_type = framework::TransToProtoVarType(low_precision_);
+          auto* prev_op =
+              in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
+          if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
+            in_var_node->Var()->SetDataType(to_type);
+            prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+            prev_op->Op()->Flush();
+          } else {
+            DoInsertCastOp(subgraphes_[i],
+                           in_var_node,
+                           op_node,
+                           in_var_type,
+                           to_type,
+                           block_desc,
+                           &suffix,
+                           &cache);
+          }
+        } else if (IsFP16AndBFP16(in_var_type) &&
                    op_run_low_precision_.count(op_type) == 0) {
-          DoInsertCastOp(subgraphes_[i],
-                         in_var_node,
-                         op_node,
-                         in_var_type,
-                         VarType::FP32,
-                         block_desc,
-                         &suffix,
-                         &cache);
+          auto to_type = VarType::FP32;
+          auto* prev_op =
+              in_var_node->inputs.empty() ? nullptr : in_var_node->inputs[0];
+          if (prev_op && GetOpOriginalType(prev_op->Op()->Type()) == "cast") {
+            in_var_node->Var()->SetDataType(to_type);
+            prev_op->Op()->SetAttr("out_dtype", static_cast<int>(to_type));
+            prev_op->Op()->Flush();
+          } else {
+            DoInsertCastOp(subgraphes_[i],
+                           in_var_node,
+                           op_node,
+                           in_var_type,
+                           to_type,
+                           block_desc,
+                           &suffix,
+                           &cache);
+          }
         }
       }
 
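The prev_op branches above are the "Do not insert redundant cast op" fix from the commit message: when the input var is itself produced by a cast, the pass retargets that cast's out_dtype (and retypes the var) instead of stacking a second cast behind it. Schematically, a sketch of the rewrite with w, x, and conv as placeholder names:

before: w --cast(out_dtype=fp32)--> x --[new cast(out_dtype=fp16)]--> conv
after:  w --cast(out_dtype=fp16)--> x(now fp16) --------------------> conv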
paddle/fluid/inference/tests/api/gpu_ernie_half_test.cc

Lines changed: 4 additions & 8 deletions
@@ -164,7 +164,7 @@ TEST(Ernie_gpu_fp16_no_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 8e-3);
     }
   }
 }
@@ -175,8 +175,6 @@ TEST(Ernie_gpu_fp16_with_ir, compare_results) {
   config.SetModel(FLAGS_infer_model);
   config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kHalf);
   config.SwitchIrOptim(true);
-  // The fc_fuse_pass has diff, which will be repaired later.
-  config.pass_builder()->DeletePass("fc_fuse_pass");
   // There is a problem with the model itself, which has nothing to do with
   // constant_folding_pass.
   config.pass_builder()->DeletePass("constant_folding_pass");
@@ -206,7 +204,7 @@ TEST(Ernie_gpu_fp16_with_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 2e-2);
     }
   }
 }
@@ -243,7 +241,7 @@ TEST(Ernie_gpu_bf16_no_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
    for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 1e-2);
     }
   }
 }
@@ -254,8 +252,6 @@ TEST(Ernie_gpu_bf16_with_ir, compare_results) {
   config.SetModel(FLAGS_infer_model);
   config.EnableUseGpu(512, 0, paddle_infer::PrecisionType::kBf16);
   config.SwitchIrOptim(true);
-  // The fc_fuse_pass has diff, which will be repaired later.
-  config.pass_builder()->DeletePass("fc_fuse_pass");
   // There is a problem with the model itself, which has nothing to do with
   // constant_folding_pass.
   config.pass_builder()->DeletePass("constant_folding_pass");
@@ -285,7 +281,7 @@ TEST(Ernie_gpu_bf16_with_ir, compare_results) {
     }
     float *result = reinterpret_cast<float *>(output.data.data());
     for (size_t j = 0; j < outputs_size; ++j) {
-      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 7e-2);
+      EXPECT_NEAR(ref[i * outputs_size + j], result[j], 5e-3);
     }
   }
 }

paddle/fluid/operators/fused/fused_fc_elementwise_layernorm_op.cu

Lines changed: 3 additions & 9 deletions
@@ -223,13 +223,7 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
 
     // For layer_norm, reduce to calculate mean and std
     sum_i += static_cast<float>(tmp_3);
-#if defined(PADDLE_WITH_CUDA) && __CUDA_ARCH__ >= 530
-    square_sum_i += static_cast<float>(__hmul(tmp_3, tmp_3));
-#elif defined(PADDLE_WITH_CUDA)
     square_sum_i += static_cast<float>(tmp_3) * static_cast<float>(tmp_3);
-#else
-    square_sum_i += static_cast<float>(tmp_3 * tmp_3);
-#endif
   }
   auto pair = BlockReduce(temp_storage)
                   .Reduce(PairForLayerNorm<float>(sum_i, square_sum_i),
@@ -282,9 +276,9 @@ __global__ void InplaceAddReluAddLayerNormKernel(const float16* y_data,
     half tmp_0 = __hdiv(__hsub(save_ptr[save_index], mean_i), std_i);
     half tmp_1 = scale ? __hmul(scale[j], tmp_0) : tmp_0;
 #else
-    half tmp_0 = static_cast<float>(static_cast<float>(save_ptr[save_index]) +
-                                    static_cast<float>(mean_i) /
-                                    static_cast<float>(std_i));
+    half tmp_0 = static_cast<half>(static_cast<float>(save_ptr[save_index]) -
+                                   static_cast<float>(mean_i) /
+                                   static_cast<float>(std_i));
     half tmp_1 = scale ? static_cast<half>(static_cast<float>(scale[j]) *
                                            static_cast<float>(tmp_0))
                        : tmp_0;
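In this fallback branch (no native half arithmetic), the removed line added the mean instead of subtracting it and materialized tmp_0 through a float cast rather than half; the normalization step, like the __hdiv(__hsub(save_ptr[save_index], mean_i), std_i) branch above, is meant to compute (x - mean) / std. A toy float-domain check of how far the old expression drifts (placeholder values, not kernel code):

#include <cassert>

int main() {
  float x = 3.0f, mean = 1.0f, std_dev = 2.0f;
  float normalized = (x - mean) / std_dev;  // layer_norm's normalization step
  float old_buggy = x + mean / std_dev;     // shape of the removed expression
  assert(normalized == 1.0f);
  assert(old_buggy == 3.5f);  // a different quantity entirely
  return 0;
}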

paddle/phi/kernels/funcs/fc_functor.cu

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ __global__ void bias_relu_v2(const int num,
 #if __CUDA_ARCH__ >= 800
     packed_val = __hmax2(__half2(0, 0), packed_val);
 #elif __CUDA_ARCH__ >= 530
-    packed_val = __hmul2(__hgt2(__half2(0, 0), packed_val), packed_val);
+    packed_val = __hmul2(__hgt2(packed_val, __half2(0, 0)), packed_val);
 #else
     packed_val.x = static_cast<int>(static_cast<float>(packed_val.x) > 0) *
                    static_cast<float>(packed_val.x);
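__hgt2(a, b) compares a > b lane-by-lane on a half2 and yields 1.0 or 0.0 per lane, so the old operand order computed (0 > x) * x: positives were zeroed and negatives passed through, an inverted ReLU. The fixed order computes (x > 0) * x, matching the __hmax2 and plain-float branches. A scalar analogue of the predicate-multiply trick (plain float standing in for the half2 lanes):

#include <cassert>

// Scalar analogue of the half2 predicate-multiply ReLU used in the kernel.
float relu_fixed(float x) { return static_cast<float>(x > 0.0f) * x; }
float relu_buggy(float x) { return static_cast<float>(0.0f > x) * x; }  // old order

int main() {
  assert(relu_fixed(2.0f) == 2.0f && relu_fixed(-2.0f) == 0.0f);
  assert(relu_buggy(2.0f) == 0.0f && relu_buggy(-2.0f) == -2.0f);  // inverted
  return 0;
}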
