@@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         }
         return failure();
       });
+  patterns.onOp(
+      "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return failure();
+        if (autoPad != "NOTSET") {
+          // TODO: Add support for `auto_pad` != "NOTSET"
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported conversion: auto_pad != NOTSET");
+        }
+
+        Torch::ValueTensorType resultType;
+        Value input, weight;
+        int64_t group;
+        if (binder.tensorOperands(input, weight) ||
+            binder.s64IntegerAttr(group, "group", 1) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
+        if (!weightTensorType || !weightTensorType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected weight type having sizes");
+        }
+        ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
+        SmallVector<int64_t> kernelShape;
+        if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
+          return failure();
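+        // `kernel_shape`, when present, is expected to match the spatial
+        // dimensions of the weight tensor; it carries no extra information.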
+        if (kernelShape.size()) {
+          if (kernelShape.size() != weightShape.size() - 2) {
+            return rewriter.notifyMatchFailure(
+                binder.op,
+                "unsupported conversion: kernel_shape list size should have "
+                "number of values equal to weight_rank - 2");
+          } else {
+            for (unsigned i = 0; i < kernelShape.size(); i++) {
+              if (weightShape[i + 2] != kernelShape[i]) {
+                return rewriter.notifyMatchFailure(
+                    binder.op, "unsupported conversion: kernel_shape value "
+                               "should be equal to the weight tensor shape");
+              }
+            }
+          }
+        }
+
+        // Determine the rank of the input tensor.
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        unsigned rank = *maybeRank;
+
+        SmallVector<int64_t> padding, strides, dilations;
+        SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations;
+        for (unsigned i = 0; i < rank - 2; i++) {
+          defaultPadding.push_back(0);
+          defaultStrides.push_back(1);
+          defaultDilations.push_back(1);
+        }
+        // Padding for the beginning and ending along each spatial axis. Each
+        // value must be greater than or equal to 0 and gives the number of
+        // pixels added at the beginning or end of the corresponding axis.
+        // `pads` has the format [x1_begin, x2_begin, ..., x1_end, x2_end,
+        // ...], where xi_begin is the number of pixels added at the beginning
+        // of axis i and xi_end the number of pixels added at the end of
+        // axis i.
+        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
+          return failure();
+        }
+        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list size does not match the number of axes");
+        }
+        if (binder.s64IntegerArrayAttr(dilations, "dilations",
+                                       defaultDilations)) {
+          return failure();
+        }
+        if (dilations.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "dilations list size does not match the number of axes");
+        }
+        if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
+          return failure();
+        }
+        if (strides.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        }
+
+        SmallVector<Value> cstPadding, cstStrides, cstDilations,
+            cstOutputPadding;
+        if (padding.size() != 2 * (rank - 2)) {
+          for (int64_t i : padding) {
+            cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+          }
+        } else {
+          for (unsigned i = 0; i < padding.size() / 2; i++) {
+            if (padding[i] != padding[i + (padding.size() / 2)]) {
+              // TODO: Add support for different padding values for the
+              // beginning and ending along each spatial axis
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "unsupported conversion: padding values for the beginning "
+                  "and ending along each spatial axis must be equal");
+            }
+            cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+          }
+        }
+        for (int64_t i : dilations) {
+          cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+        }
+        for (int64_t i : strides) {
+          cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+        }
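+        // `output_padding` only applies to transposed convolution, but
+        // torch.aten.convolution still expects the argument; pass zeros (two
+        // entries here, matching the 2-D case).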
+        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+        cstOutputPadding = {cstZero, cstZero};
+
+        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstPadding);
+        Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstDilations);
+        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstStrides);
+        Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstOutputPadding);
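+        // torch.aten.convolution covers both normal and transposed
+        // convolution; Conv is the non-transposed case.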
+        Value transposed =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
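+        // The bias operand is optional in ONNX; lower it when present,
+        // otherwise pass `none`.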
+        Value bias;
+        if (binder.op->getNumOperands() == 3) {
+          if (binder.tensorOperandAtIndex(bias, 2)) {
+            return failure();
+          }
+        } else {
+          bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        }
+        Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(group));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
+            binder.op, resultType, input, weight, bias, stridesList,
+            paddingList, dilationsList, transposed, outputPaddingList,
+            cstGroup);
+        return success();
+      });
+  patterns.onOp(
+      "ConvTranspose", 11,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        std::string autoPad;
+        if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET"))
+          return failure();
+        if (autoPad != "NOTSET") {
+          // TODO: Add support for `auto_pad` != "NOTSET"
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported conversion: auto_pad != NOTSET");
+        }
+        SmallVector<int64_t> outputShape;
+        if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {}))
+          return failure();
+        if (outputShape.size()) {
+          // TODO: Add support for non-None output_shape value.
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unsupported conversion: output_shape should be absent");
+        }
+        Torch::ValueTensorType resultType;
+        Value input, weight;
+        int64_t group;
+        if (binder.tensorOperands(input, weight) ||
+            binder.s64IntegerAttr(group, "group", 1) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
+        if (!weightTensorType || !weightTensorType.hasSizes()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "Expected weight type having sizes");
+        }
+        ArrayRef<int64_t> weightShape = weightTensorType.getSizes();
+        SmallVector<int64_t> kernelShape;
+        if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}))
+          return failure();
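+        // As in Conv above, `kernel_shape`, when present, is expected to
+        // match the spatial dimensions of the weight tensor.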
+        if (kernelShape.size()) {
+          if (kernelShape.size() != weightShape.size() - 2) {
+            return rewriter.notifyMatchFailure(
+                binder.op,
+                "unsupported conversion: kernel_shape list size should have "
+                "number of values equal to weight_rank - 2");
+          } else {
+            for (unsigned i = 0; i < kernelShape.size(); i++) {
+              if (weightShape[i + 2] != kernelShape[i]) {
+                return rewriter.notifyMatchFailure(
+                    binder.op, "unsupported conversion: kernel_shape value "
+                               "should be equal to the weight tensor shape");
+              }
+            }
+          }
+        }
+
+        // Determine the rank of the input tensor.
+        std::optional<unsigned> maybeRank = Torch::getTensorRank(input);
+        if (!maybeRank)
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Unimplemented: unranked tensor");
+        unsigned rank = *maybeRank;
+
+        SmallVector<int64_t> padding, strides, dilations, outputPadding;
+        SmallVector<int64_t> defaultPadding, defaultStrides, defaultDilations,
+            defaultOutputPadding;
+        for (unsigned i = 0; i < rank - 2; i++) {
+          defaultPadding.push_back(0);
+          defaultStrides.push_back(1);
+          defaultDilations.push_back(1);
+          defaultOutputPadding.push_back(0);
+        }
+        // Padding for the beginning and ending along each spatial axis. Each
+        // value must be greater than or equal to 0 and gives the number of
+        // pixels added at the beginning or end of the corresponding axis.
+        // `pads` has the format [x1_begin, x2_begin, ..., x1_end, x2_end,
+        // ...], where xi_begin is the number of pixels added at the beginning
+        // of axis i and xi_end the number of pixels added at the end of
+        // axis i.
+        if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) {
+          return failure();
+        }
+        if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "padding list size does not match the number of axes");
+        }
+        if (binder.s64IntegerArrayAttr(dilations, "dilations",
+                                       defaultDilations)) {
+          return failure();
+        }
+        if (dilations.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "dilations list size does not match the number of axes");
+        }
+        if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) {
+          return failure();
+        }
+        if (strides.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "strides list size does not match the number of axes");
+        }
+        if (binder.s64IntegerArrayAttr(outputPadding, "output_padding",
+                                       defaultOutputPadding)) {
+          return failure();
+        }
+        if (outputPadding.size() != rank - 2) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "output_padding list size does not match the number of axes");
+        }
+
+        SmallVector<Value> cstPadding, cstStrides, cstDilations,
+            cstOutputPadding;
+        if (padding.size() != 2 * (rank - 2)) {
+          for (int64_t i : padding) {
+            cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+          }
+        } else {
+          for (unsigned i = 0; i < padding.size() / 2; i++) {
+            if (padding[i] != padding[i + (padding.size() / 2)]) {
+              // TODO: Add support for different padding values for the
+              // beginning and ending along each spatial axis
+              return rewriter.notifyMatchFailure(
+                  binder.op,
+                  "unsupported conversion: padding values for the beginning "
+                  "and ending along each spatial axis must be equal");
+            }
+            cstPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getI64IntegerAttr(padding[i])));
+          }
+        }
+        for (int64_t i : dilations) {
+          cstDilations.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+        }
+        for (int64_t i : strides) {
+          cstStrides.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+        }
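+        // For transposed convolution, `output_padding` adds extra size to one
+        // side of each spatial output dimension, disambiguating the output
+        // shape when stride > 1.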
+        for (int64_t i : outputPadding) {
+          cstOutputPadding.push_back(rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+        }
+
+        Value paddingList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstPadding);
+        Value dilationsList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstDilations);
+        Value stridesList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstStrides);
+        Value outputPaddingList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            cstOutputPadding);
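+        // Same lowering as Conv above, but with `transposed` set to true.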
+        Value transposed =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
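+        // As above, the optional bias operand defaults to `none` when absent.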
+        Value bias;
+        if (binder.op->getNumOperands() == 3) {
+          if (binder.tensorOperandAtIndex(bias, 2)) {
+            return failure();
+          }
+        } else {
+          bias = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        }
+        Value cstGroup = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(group));
+
+        rewriter.replaceOpWithNewOp<Torch::AtenConvolutionOp>(
+            binder.op, resultType, input, weight, bias, stridesList,
+            paddingList, dilationsList, transposed, outputPaddingList,
+            cstGroup);
+        return success();
+      });
   patterns.onOp("Cos", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;