@@ -699,6 +699,7 @@ func.func @illegal_im2col_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x10
699699 %1 = iree_linalg_ext.im2col strides = [1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
700700 m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
701701 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
702+ input_k_perm = [0 , 1 , 2 ]
702703 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
703704 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
704705 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -712,6 +713,7 @@ func.func @illegal_im2col_dilations(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x
712713 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 , 1 ] kernel_size = [3 , 3 ]
713714 m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
714715 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
716+ input_k_perm = [0 , 1 , 2 ]
715717 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
716718 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
717719 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -725,6 +727,7 @@ func.func @illegal_im2col_kernel_size(%arg0: tensor<2x34x34x640xf32>) -> tensor<
725727 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 ]
726728 m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
727729 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
730+ input_k_perm = [0 , 1 , 2 ]
728731 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
729732 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
730733 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -738,6 +741,7 @@ func.func @illegal_im2col_m_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1
738741 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
739742 m_offset = [0 , 0 ] * [1 ] k_offset = [0 ] * [1 ]
740743 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
744+ input_k_perm = [0 , 1 , 2 ]
741745 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
742746 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
743747 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -751,6 +755,7 @@ func.func @illegal_im2col_k_offset(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x1
751755 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
752756 m_offset = [0 ] * [1 ] k_offset = [0 , 0 ] * [1 ]
753757 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
758+ input_k_perm = [0 , 1 , 2 ]
754759 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
755760 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
756761 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -764,6 +769,7 @@ func.func @illegal_im2col_m_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x
764769 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
765770 m_offset = [0 ] * [0 ] k_offset = [0 ] * [1 ]
766771 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
772+ input_k_perm = [0 , 1 , 2 ]
767773 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
768774 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
769775 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -777,6 +783,7 @@ func.func @illegal_im2col_k_strides(%arg0: tensor<2x34x34x640xf32>) -> tensor<2x
777783 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
778784 m_offset = [0 ] * [1 ] k_offset = [0 ] * [2 ]
779785 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
786+ input_k_perm = [0 , 1 , 2 ]
780787 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
781788 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
782789 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -790,6 +797,7 @@ func.func @illegal_im2col_input_rank(%arg0: tensor<1x2x34x34x640xf32>) -> tensor
790797 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
791798 m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
792799 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
800+ input_k_perm = [0 , 1 , 2 ]
793801 ins (%arg0 : tensor <1 x2 x34 x34 x640 xf32 >)
794802 outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
795803 return %1 : tensor <2 x1024 x5760 xf32 >
@@ -803,13 +811,42 @@ func.func @illegal_im2col_output_rank(%arg0: tensor<2x34x34x640xf32>) -> tensor<
803811 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
804812 m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
805813 batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
814+ input_k_perm = [0 , 1 , 2 ]
806815 ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
807816 outs (%0 : tensor <2 x1024 x9 x640 xf32 >) -> tensor <2 x1024 x9 x640 xf32 >
808817 return %1 : tensor <2 x1024 x9 x640 xf32 >
809818}
810819
811820// -----
812821
822+ func.func @illegal_im2col_perm_num (%arg0: tensor <2 x34 x34 x640 xf32 >) -> tensor <2 x1024 x5760 xf32 > {
823+ %0 = tensor.empty () : tensor <2 x1024 x5760 xf32 >
824+ // expected-error @+1 {{expected input_k_perm size (2) to match the number of shared dimensions (m_Pos + k_pos = 3)}}
825+ %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
826+ m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
827+ batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
828+ input_k_perm = [0 , 1 ]
829+ ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
830+ outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
831+ return %1 : tensor <2 x1024 x5760 xf32 >
832+ }
833+
834+ // -----
835+
836+ func.func @illegal_im2col_perm_value (%arg0: tensor <2 x34 x34 x640 xf32 >) -> tensor <2 x1024 x5760 xf32 > {
837+ %0 = tensor.empty () : tensor <2 x1024 x5760 xf32 >
838+ // expected-error @+1 {{expected input_k_perm to be a permutation of [0, 3)}}
839+ %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
840+ m_offset = [0 ] * [1 ] k_offset = [0 ] * [1 ]
841+ batch_pos = [0 ] m_pos = [1 , 2 ] k_pos = [3 ]
842+ input_k_perm = [1 , 2 , 3 ]
843+ ins (%arg0 : tensor <2 x34 x34 x640 xf32 >)
844+ outs (%0 : tensor <2 x1024 x5760 xf32 >) -> tensor <2 x1024 x5760 xf32 >
845+ return %1 : tensor <2 x1024 x5760 xf32 >
846+ }
847+
848+ // -----
849+
813850func.func @illegal_winograd_input_shape (%arg0: tensor <1 x10 x10 x32 xf32 >) -> tensor <8 x8 x1 x6 x6 x32 xf32 > {
814851 %0 = tensor.empty () : tensor <8 x8 x1 x6 x6 x32 xf32 >
815852 // expected-error @+1 {{incompatible output shape}}
0 commit comments