@@ -788,6 +788,244 @@ module attributes { transform.with_named_sequence } {
788788
789789// -----
790790
791+ func.func @arg_compare_tile_tensor (
792+ %input: tensor <?x?xf32 >,
793+ %outv: tensor <?xf32 >,
794+ %outi: tensor <?xi32 >
795+ ) -> (tensor <?xf32 >, tensor <?xi32 >) {
796+ %0:2 = iree_linalg_ext.arg_compare
797+ dimension (1 )
798+ ins (%input : tensor <?x?xf32 >)
799+ outs (%outv , %outi : tensor <?xf32 >, tensor <?xi32 >) {
800+ ^bb0 (%a: f32 , %b: f32 ):
801+ %cmp = arith.cmpf ogt , %a , %b : f32
802+ iree_linalg_ext.yield %cmp : i1
803+ } -> tensor <?xf32 >, tensor <?xi32 >
804+ return %0#0 , %0#1 : tensor <?xf32 >, tensor <?xi32 >
805+ }
806+
807+ module attributes { transform.with_named_sequence } {
808+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
809+ %0 = transform.structured.match ops {[" iree_linalg_ext.arg_compare" ]} in %module_op
810+ : (!transform.any_op ) -> !transform.any_op
811+ %1 , %loops = transform.structured.tile_using_for %0 tile_sizes [10 , 0 ]
812+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
813+ transform.yield
814+ }
815+ }
816+
817+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
818+ // CHECK: func.func @arg_compare_tile_tensor
819+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
820+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]
821+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]
822+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
823+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
824+ // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
825+ // CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
826+ // CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
827+ // CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[C10]] iter_args(%[[V0:.+]] = %[[ARG1]], %[[V1:.+]] = %[[ARG2]])
828+ // CHECK: %[[MIN:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D0]]]
829+ // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[MIN]], %[[D1]]] [1, 1]
830+ // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[V0]][%[[IV]]] [%[[MIN]]] [1]
831+ // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[V1]][%[[IV]]] [%[[MIN]]] [1]
832+ // CHECK: %[[CMP:.+]]:2 = iree_linalg_ext.arg_compare
833+ // CHECK-SAME: ins(%[[SLICE0]]
834+ // CHECK-SAME: outs(%[[SLICE1]], %[[SLICE2]]
835+ // CHECK: %[[INS0:.+]] = tensor.insert_slice %[[CMP]]#0 into %[[V0]][%[[IV]]] [%[[MIN]]] [1]
836+ // CHECK: %[[INS1:.+]] = tensor.insert_slice %[[CMP]]#1 into %[[V1]][%[[IV]]] [%[[MIN]]] [1]
837+ // CHECK: scf.yield %[[INS0]], %[[INS1]]
838+ // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
839+
840+ // -----
841+
842+ func.func @arg_compare_tile_memref (
843+ %input: memref <?x?xf32 >,
844+ %outv: memref <?xf32 >,
845+ %outi: memref <?xi32 >
846+ ) {
847+ iree_linalg_ext.arg_compare
848+ dimension (1 )
849+ ins (%input : memref <?x?xf32 >)
850+ outs (%outv , %outi : memref <?xf32 >, memref <?xi32 >) {
851+ ^bb0 (%a: f32 , %b: f32 ):
852+ %cmp = arith.cmpf ogt , %a , %b : f32
853+ iree_linalg_ext.yield %cmp : i1
854+ }
855+ return
856+ }
857+
858+ module attributes { transform.with_named_sequence } {
859+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
860+ %0 = transform.structured.match ops {[" iree_linalg_ext.arg_compare" ]} in %module_op
861+ : (!transform.any_op ) -> !transform.any_op
862+ %1 , %loops = transform.structured.tile_using_for %0 tile_sizes [10 , 0 ]
863+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
864+ transform.yield
865+ }
866+ }
867+
868+ // CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
869+ // CHECK: func.func @arg_compare_tile_memref
870+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
871+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
872+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
873+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
874+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
875+ // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
876+ // CHECK: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
877+ // CHECK: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
878+ // CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[C10]]
879+ // CHECK: %[[MIN:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D0]]]
880+ // CHECK: %[[SV0:.+]] = memref.subview %[[ARG0]][%[[IV]], 0] [%[[MIN]], %[[D1]]] [1, 1]
881+ // CHECK: %[[SV1:.+]] = memref.subview %[[ARG1]][%[[IV]]] [%[[MIN]]] [1]
882+ // CHECK: %[[SV2:.+]] = memref.subview %[[ARG2]][%[[IV]]] [%[[MIN]]] [1]
883+ // CHECK: iree_linalg_ext.arg_compare
884+ // CHECK-SAME: dimension(1)
885+ // CHECK-SAME: ins(%[[SV0]]
886+ // CHECK-SAME: outs(%[[SV1]], %[[SV2]]
887+ // CHECK: return
888+
889+ // -----
890+
891+ func.func @arg_compare_1d (%input: tensor <128 xf32 >) -> tensor <i32 > {
892+ %outv = tensor.empty () : tensor <f32 >
893+ %outi = tensor.empty () : tensor <i32 >
894+ %result:2 = iree_linalg_ext.arg_compare
895+ dimension (0 )
896+ ins (%input : tensor <128 xf32 >)
897+ outs (%outv , %outi : tensor <f32 >, tensor <i32 >) {
898+ ^bb0 (%a: f32 , %b: f32 ):
899+ %cmp = arith.cmpf ogt , %a , %b : f32
900+ iree_linalg_ext.yield %cmp : i1
901+ } -> tensor <f32 >, tensor <i32 >
902+ return %result#1 : tensor <i32 >
903+ }
904+
905+ module attributes { transform.with_named_sequence } {
906+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
907+ %0 = transform.structured.match ops {[" iree_linalg_ext.arg_compare" ]} in %module_op
908+ : (!transform.any_op ) -> !transform.any_op
909+ %1 = transform.structured.tile_using_for %0 tile_sizes [0 ]
910+ : (!transform.any_op ) -> (!transform.any_op )
911+ transform.yield
912+ }
913+ }
914+
915+ // CHECK: func.func @arg_compare_1d(
916+ // CHECK-SAME: %[[OPERAND:.+]]: tensor<128xf32>
917+ // CHECK: %[[ACCV:.+]] = tensor.empty() : tensor<f32>
918+ // CHECK: %[[ACCI:.+]] = tensor.empty() : tensor<i32>
919+ // CHECK: %[[RESULT:.+]]:2 = iree_linalg_ext.arg_compare
920+ // CHECK-SAME: ins(%[[OPERAND]] :
921+ // CHECK-SAME: outs(%[[ACCV]], %[[ACCI]] :
922+ // CHECK: return %[[RESULT]]#1
923+
924+ // -----
925+
926+ func.func @arg_compare_2d_dim0 (%input: tensor <16 x32 xf32 >) -> tensor <32 xi32 > {
927+ %outv = tensor.empty () : tensor <32 xf32 >
928+ %outi = tensor.empty () : tensor <32 xi32 >
929+ %result:2 = iree_linalg_ext.arg_compare
930+ dimension (0 )
931+ ins (%input : tensor <16 x32 xf32 >)
932+ outs (%outv , %outi : tensor <32 xf32 >, tensor <32 xi32 >) {
933+ ^bb0 (%a: f32 , %b: f32 ):
934+ %cmp = arith.cmpf ogt , %a , %b : f32
935+ iree_linalg_ext.yield %cmp : i1
936+ } -> tensor <32 xf32 >, tensor <32 xi32 >
937+ return %result#1 : tensor <32 xi32 >
938+ }
939+
940+ module attributes { transform.with_named_sequence } {
941+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
942+ %0 = transform.structured.match ops {[" iree_linalg_ext.arg_compare" ]} in %module_op
943+ : (!transform.any_op ) -> !transform.any_op
944+ // Only tile the non-reduction dimension: columns.
945+ %1 , %loops = transform.structured.tile_using_for %0 tile_sizes [0 , 20 ]
946+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
947+ transform.yield
948+ }
949+ }
950+
951+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (-d0 + 32, 20)>
952+ // CHECK: func.func @arg_compare_2d_dim0(
953+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]
954+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
955+ // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
956+ // CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
957+ // CHECK-DAG: %[[ACCV:.+]] = tensor.empty() : tensor<32xf32>
958+ // CHECK-DAG: %[[ACCI:.+]] = tensor.empty() : tensor<32xi32>
959+ // CHECK: %[[RESULT:.+]]:2 = scf.for %[[I:.+]] = %[[C0]] to %[[C32]] step %[[C20]]
960+ // CHECK-SAME: iter_args(%[[ARG2:.+]] = %[[ACCV]], %[[ARG3:.+]] = %[[ACCI]])
961+ // CHECK: %[[SIZE:.+]] = affine.min #[[MAP0]](%[[I]])
962+ // CHECK: %[[UPDATE_SLICE_IN:.+]] = tensor.extract_slice %[[ARG0]][0, %[[I]]] [16, %[[SIZE]]] [1, 1]
963+ // CHECK: %[[UPDATE_SLICE_OUTV:.+]] = tensor.extract_slice %[[ARG2]][%[[I]]] [%[[SIZE]]] [1]
964+ // CHECK: %[[UPDATE_SLICE_OUTI:.+]] = tensor.extract_slice %[[ARG3]][%[[I]]] [%[[SIZE]]] [1]
965+ // CHECK: %[[ARGCMP_TILE:.+]]:2 = iree_linalg_ext.arg_compare
966+ // CHECK-SAME: dimension(0)
967+ // CHECK-SAME: ins(%[[UPDATE_SLICE_IN]]
968+ // CHECK-SAME: outs(%[[UPDATE_SLICE_OUTV]], %[[UPDATE_SLICE_OUTI]]
969+ // CHECK: %[[ACCV_YIELD:.+]] = tensor.insert_slice %[[ARGCMP_TILE]]#0 into %[[ARG2]][%[[I]]] [%[[SIZE]]] [1]
970+ // CHECK: %[[ACCI_YIELD:.+]] = tensor.insert_slice %[[ARGCMP_TILE]]#1 into %[[ARG3]][%[[I]]] [%[[SIZE]]] [1]
971+ // CHECK: scf.yield %[[ACCV_YIELD]], %[[ACCI_YIELD]] : tensor<32xf32>, tensor<32xi32>
972+ // CHECK: return %[[RESULT]]#1
973+
974+ // -----
975+
976+ func.func @arg_compare_with_base (
977+ %input : tensor <2 x6 xf32 >,
978+ %outv : tensor <2 xf32 >,
979+ %outi : tensor <2 xindex >,
980+ %base : index
981+ ) -> (tensor <2 xf32 >, tensor <2 xindex >) {
982+ %0:2 = iree_linalg_ext.arg_compare
983+ dimension (1 )
984+ ins (%input : tensor <2 x6 xf32 >)
985+ outs (%outv , %outi : tensor <2 xf32 >, tensor <2 xindex >)
986+ index_base (%base : index ) {
987+ ^bb0 (%a: f32 , %b: f32 ):
988+ %cmp = arith.cmpf ogt , %a , %b : f32
989+ iree_linalg_ext.yield %cmp : i1
990+ } -> tensor <2 xf32 >, tensor <2 xindex >
991+ return %0#0 , %0#1 : tensor <2 xf32 >, tensor <2 xindex >
992+ }
993+
994+ module attributes { transform.with_named_sequence } {
995+ transform.named_sequence @__transform_main (%module_op: !transform.any_op {transform.readonly }) {
996+ %0 = transform.structured.match ops {[" iree_linalg_ext.arg_compare" ]} in %module_op
997+ : (!transform.any_op ) -> !transform.any_op
998+ %1 , %loops = transform.structured.tile_using_for %0 tile_sizes [1 , 0 ]
999+ : (!transform.any_op ) -> (!transform.any_op , !transform.any_op )
1000+ transform.yield
1001+ }
1002+ }
1003+
1004+ // CHECK-LABEL: func.func @arg_compare_with_base(
1005+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<2x6xf32>
1006+ // CHECK-SAME: %[[OUTV:[a-zA-Z0-9_]+]]: tensor<2xf32>
1007+ // CHECK-SAME: %[[OUTI:[a-zA-Z0-9_]+]]: tensor<2xindex>
1008+ // CHECK-SAME: %[[BASE:[a-zA-Z0-9_]+]]: index
1009+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1010+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
1011+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1012+ // CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C2]] step %[[C1]]
1013+ // CHECK-SAME: iter_args(%[[VARG:.+]] = %[[OUTV]], %[[IARG:.+]] = %[[OUTI]])
1014+ // CHECK: %[[SLICE_IN:.+]] = tensor.extract_slice %[[INPUT]][%[[IV]], 0] [1, 6] [1, 1]
1015+ // CHECK: %[[SLICE_OUTV:.+]] = tensor.extract_slice %[[VARG]][%[[IV]]] [1] [1]
1016+ // CHECK: %[[SLICE_OUTI:.+]] = tensor.extract_slice %[[IARG]][%[[IV]]] [1] [1]
1017+ // CHECK: %[[ARGCMP:.+]]:2 = iree_linalg_ext.arg_compare
1018+ // CHECK-SAME: dimension(1)
1019+ // CHECK-SAME: ins(%[[SLICE_IN]]
1020+ // CHECK-SAME: outs(%[[SLICE_OUTV]], %[[SLICE_OUTI]]
1021+ // CHECK-SAME: index_base(%[[BASE]]
1022+ // CHECK: %[[INS_OUTV:.+]] = tensor.insert_slice %[[ARGCMP]]#0 into %[[VARG]][%[[IV]]] [1] [1]
1023+ // CHECK: %[[INS_OUTI:.+]] = tensor.insert_slice %[[ARGCMP]]#1 into %[[IARG]][%[[IV]]] [1] [1]
1024+ // CHECK: scf.yield %[[INS_OUTV]], %[[INS_OUTI]]
1025+ // CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
1026+
1027+ // -----
1028+
7911029func.func @im2col (%arg0: tensor <2 x34 x34 x640 xf32 >) -> tensor <2 x1024 x5760 xf32 > {
7921030 %0 = tensor.empty () : tensor <2 x1024 x5760 xf32 >
7931031 %1 = iree_linalg_ext.im2col strides = [1 , 1 ] dilations = [1 , 1 ] kernel_size = [3 , 3 ]
0 commit comments