Skip to content

Commit bb79a54

Browse files
[Precision Depth Alignment] Change the pad_value parameter of pad3d from float to double (PaddlePaddle#75970)
* Change the pad_falue parameter of pad3d from float to double * fix Pad3dInferMeta * fix1 * fix2 * fix3 * fix * fix op_translator * fix Pad3dOpTranscriber
1 parent b4a5484 commit bb79a54

File tree

21 files changed

+103
-19
lines changed

21 files changed

+103
-19
lines changed

paddle/fluid/inference/tensorrt/pir/generic_plugin.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -704,8 +704,14 @@ int GenericPlugin::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
704704
phi_kernel_contexts_[data_type]->EmplaceBackAttr(
705705
attrs_map_[t].dyn_cast<::pir::FloatAttribute>().data());
706706
} else if (attr_type_name == "pir::DoubleAttribute") {
707-
phi_kernel_contexts_[data_type]->EmplaceBackAttr(
708-
attrs_map_[t].dyn_cast<::pir::DoubleAttribute>().data());
707+
if (attrs_map_[t].type_id() == ::pir::FloatAttribute::type_id()) {
708+
const auto val = attrs_map_[t].dyn_cast<::pir::FloatAttribute>().data();
709+
phi_kernel_contexts_[data_type]->EmplaceBackAttr(
710+
static_cast<double>(val));
711+
} else {
712+
phi_kernel_contexts_[data_type]->EmplaceBackAttr(
713+
attrs_map_[t].dyn_cast<::pir::DoubleAttribute>().data());
714+
}
709715
} else if (attr_type_name == "pir::BoolAttribute") {
710716
phi_kernel_contexts_[data_type]->EmplaceBackAttr(
711717
attrs_map_[t].dyn_cast<::pir::BoolAttribute>().data());

paddle/fluid/ir_adaptor/translator/op_translator.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4031,6 +4031,43 @@ struct LogitOpTranscriber : public OpTranscriber {
40314031
}
40324032
};
40334033

4034+
struct Pad3dOpTranscriber : public OpTranscriber {
4035+
pir::AttributeMap TranslateOpAttribute(
4036+
pir::IrContext* ctx,
4037+
const std::string& normalized_op_name,
4038+
const OpAttributeInfoList& op_attr_infos,
4039+
const OpDesc& op_desc) override {
4040+
auto& attribute_translator = AttributeTranslator::instance();
4041+
auto& op_normalizer = OpNameNormalizer::instance();
4042+
pir::AttributeMap attribute_map = {};
4043+
4044+
for (const auto& info : op_attr_infos) {
4045+
auto legacy_attr_name =
4046+
op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name);
4047+
VLOG(10) << "[op: " << op_desc.Type()
4048+
<< "][attr] from: " << legacy_attr_name << " to: " << info.name;
4049+
if (op_desc.HasAttr(legacy_attr_name)) {
4050+
paddle::framework::Attribute legacy_attr =
4051+
op_desc.GetAttr(legacy_attr_name);
4052+
VLOG(10) << "attribute in " << op_desc.Type()
4053+
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
4054+
pir::Attribute new_attr =
4055+
attribute_translator(info.type_name, legacy_attr);
4056+
if (info.name == "pad_value") {
4057+
new_attr = pir::DoubleAttribute::get(
4058+
ctx,
4059+
static_cast<double>(
4060+
new_attr.dyn_cast<pir::FloatAttribute>().data()));
4061+
}
4062+
attribute_map[info.name] = new_attr;
4063+
} else {
4064+
this->HandleNonexistentAttribute(ctx, &attribute_map, info);
4065+
}
4066+
}
4067+
return attribute_map;
4068+
}
4069+
};
4070+
40344071
OpTranslator::OpTranslator() {
40354072
pir::IrContext* ctx = pir::IrContext::Instance();
40364073
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -4149,5 +4186,7 @@ OpTranslator::OpTranslator() {
41494186
special_handlers["softplus_grad"] = SoftPlusOpTranscriber();
41504187
special_handlers["logit"] = LogitOpTranscriber();
41514188
special_handlers["logit_grad"] = LogitOpTranscriber();
4189+
special_handlers["pad3d"] = Pad3dOpTranscriber();
4190+
special_handlers["pad3d_grad"] = Pad3dOpTranscriber();
41524191
}
41534192
} // namespace paddle::translator

paddle/fluid/pir/serialize_deserialize/patch/0.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,8 @@ op_patches:
3131
- action : modify_attr
3232
object : eps
3333
type : pir::DoubleAttribute
34+
- op_name : pd_op.pad3d
35+
actions:
36+
- action : modify_attr
37+
object : pad_value
38+
type : pir::DoubleAttribute

paddle/phi/infermeta/unary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3428,7 +3428,7 @@ void PadInferMeta(const MetaTensor& input,
34283428
void Pad3dInferMeta(const MetaTensor& x,
34293429
const IntArray& paddings_int_array,
34303430
const std::string& mode,
3431-
float value,
3431+
double value,
34323432
const std::string& data_format,
34333433
MetaTensor* out,
34343434
MetaConfig config) {

paddle/phi/infermeta/unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ PADDLE_API void PadInferMeta(const MetaTensor& input,
559559
PADDLE_API void Pad3dInferMeta(const MetaTensor& x,
560560
const IntArray& paddings,
561561
const std::string& mode,
562-
float value,
562+
double value,
563563
const std::string& data_format,
564564
MetaTensor* out,
565565
MetaConfig config = MetaConfig());

paddle/phi/kernels/cpu/pad3d_grad_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ void Pad3dGradKernel(const Context& dev_ctx,
364364
const DenseTensor& out_grad,
365365
const IntArray& paddings,
366366
const std::string& mode,
367-
float pad_value UNUSED,
367+
double pad_value UNUSED,
368368
const std::string& data_format,
369369
DenseTensor* x_grad) {
370370
std::vector<int64_t> pads = paddings.GetData();

paddle/phi/kernels/cpu/pad3d_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ void Pad3dKernel(const Context& dev_ctx,
381381
const DenseTensor& x,
382382
const IntArray& paddings,
383383
const std::string& mode,
384-
float pad_value,
384+
double pad_value,
385385
const std::string& data_format,
386386
DenseTensor* out) {
387387
T value = static_cast<T>(pad_value);

paddle/phi/kernels/gpu/pad3d_grad_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ void Pad3dGradKernel(const Context& dev_ctx,
343343
const DenseTensor& out_grad,
344344
const IntArray& paddings,
345345
const std::string& mode,
346-
float pad_value,
346+
double pad_value,
347347
const std::string& data_format,
348348
DenseTensor* x_grad) {
349349
std::vector<int64_t> pads = paddings.GetData();

paddle/phi/kernels/gpu/pad3d_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ void Pad3dKernel(const Context& dev_ctx,
333333
const DenseTensor& x,
334334
const IntArray& paddings,
335335
const std::string& mode,
336-
float pad_value,
336+
double pad_value,
337337
const std::string& data_format,
338338
DenseTensor* out) {
339339
std::vector<int64_t> pads = paddings.GetData();

paddle/phi/kernels/onednn/pad3d_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ void Pad3dKernel(const Context& dev_ctx,
5252
const DenseTensor& x,
5353
const IntArray& paddings,
5454
const std::string& mode UNUSED,
55-
float pad_value,
55+
double pad_value,
5656
const std::string& data_format UNUSED,
5757
DenseTensor* out) {
5858
PadOpKernel<T, Context>(dev_ctx, x, paddings.GetData(), pad_value, out);

0 commit comments

Comments
 (0)