@@ -4031,6 +4031,43 @@ struct LogitOpTranscriber : public OpTranscriber {
40314031 }
40324032};
40334033
4034+ struct Pad3dOpTranscriber : public OpTranscriber {
4035+ pir::AttributeMap TranslateOpAttribute (
4036+ pir::IrContext* ctx,
4037+ const std::string& normalized_op_name,
4038+ const OpAttributeInfoList& op_attr_infos,
4039+ const OpDesc& op_desc) override {
4040+ auto & attribute_translator = AttributeTranslator::instance ();
4041+ auto & op_normalizer = OpNameNormalizer::instance ();
4042+ pir::AttributeMap attribute_map = {};
4043+
4044+ for (const auto & info : op_attr_infos) {
4045+ auto legacy_attr_name =
4046+ op_normalizer.GetLegacyAttrName (op_desc.Type (), info.name );
4047+ VLOG (10 ) << " [op: " << op_desc.Type ()
4048+ << " ][attr] from: " << legacy_attr_name << " to: " << info.name ;
4049+ if (op_desc.HasAttr (legacy_attr_name)) {
4050+ paddle::framework::Attribute legacy_attr =
4051+ op_desc.GetAttr (legacy_attr_name);
4052+ VLOG (10 ) << " attribute in " << op_desc.Type ()
4053+ << " name: " << legacy_attr_name << " " << legacy_attr.index ();
4054+ pir::Attribute new_attr =
4055+ attribute_translator (info.type_name , legacy_attr);
4056+ if (info.name == " pad_value" ) {
4057+ new_attr = pir::DoubleAttribute::get (
4058+ ctx,
4059+ static_cast <double >(
4060+ new_attr.dyn_cast <pir::FloatAttribute>().data ()));
4061+ }
4062+ attribute_map[info.name ] = new_attr;
4063+ } else {
4064+ this ->HandleNonexistentAttribute (ctx, &attribute_map, info);
4065+ }
4066+ }
4067+ return attribute_map;
4068+ }
4069+ };
4070+
40344071OpTranslator::OpTranslator () {
40354072 pir::IrContext* ctx = pir::IrContext::Instance ();
40364073 ctx->GetOrRegisterDialect <paddle::dialect::OperatorDialect>();
@@ -4149,5 +4186,7 @@ OpTranslator::OpTranslator() {
41494186 special_handlers[" softplus_grad" ] = SoftPlusOpTranscriber ();
41504187 special_handlers[" logit" ] = LogitOpTranscriber ();
41514188 special_handlers[" logit_grad" ] = LogitOpTranscriber ();
4189+ special_handlers[" pad3d" ] = Pad3dOpTranscriber ();
4190+ special_handlers[" pad3d_grad" ] = Pad3dOpTranscriber ();
41524191}
41534192} // namespace paddle::translator
0 commit comments