Skip to content

Commit 927c514

Browse files
committed
added 2 more test-cases
1 parent 2257cf2 commit 927c514

File tree

1 file changed

+175
-34
lines changed

1 file changed

+175
-34
lines changed
Lines changed: 175 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,189 @@
11
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s
22

33
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
4-
#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
4+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
55
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
6-
7-
func.func @transpose_matrix_no_conversion_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
8-
%cst = arith.constant 0.000000e+00 : f32
9-
%c1 = arith.constant 1 : index
10-
%c16 = arith.constant 16 : index
11-
%c64 = arith.constant 64 : index
12-
%c128 = arith.constant 128 : index
13-
%c4 = arith.constant 4 : index
14-
%c32 = arith.constant 32 : index
15-
%c0 = arith.constant 0 : index
16-
17-
scf.for %arg5 = %c0 to %c32 step %c4 {
18-
scf.for %arg6 = %c0 to %c128 step %c64 {
19-
%subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
20-
%2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
21-
%con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
22-
%con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
23-
%subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
24-
%subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
25-
%0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
26-
%1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
27-
%3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
28-
scf.yield %3 : vector<4x64xf32>
6+
memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
7+
func.func @lower_contract_to_fma(%arg0: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
8+
%cst = arith.constant 0.000000e+00 : f32
9+
%cst_0 = arith.constant dense<0.000000e+00> : vector<32x64xf32>
10+
%c1 = arith.constant 1 : index
11+
%c24 = arith.constant 24 : index
12+
%c64 = arith.constant 64 : index
13+
%c4 = arith.constant 4 : index
14+
%c32 = arith.constant 32 : index
15+
%c0 = arith.constant 0 : index
16+
%0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
17+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
18+
scf.forall (%arg1, %arg2) in (8, 24) {
19+
%subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
20+
vector.transfer_write %cst_0, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
21+
%subview_1 = memref.subview %arg0[%arg1, 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
22+
scf.for %arg3 = %c0 to %c32 step %c4 {
23+
scf.for %arg4 = %c0 to %c64 step %c64 {
24+
%subview_2 = memref.subview %subview[%arg3, %arg4] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
25+
%1 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
26+
%2 = scf.for %arg5 = %c0 to %c24 step %c1 iter_args(%arg6 = %1) -> (vector<4x64xf32>) {
27+
%3 = scf.for %arg7 = %c0 to %c64 step %c1 iter_args(%arg8 = %arg6) -> (vector<4x64xf32>) {
28+
%subview_3 = memref.subview %subview_1[%arg5, %arg3, %arg7] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
29+
%subview_4 = memref.subview %0[%arg5, %arg7, %arg4] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
30+
%4 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>, vector<1x4x1xf32>
31+
%5 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<1x1x64xf32>
32+
%6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg8 : vector<1x4x1xf32>, vector<1x1x64xf32> into vector<4x64xf32>
33+
scf.yield %6 : vector<4x64xf32>
2934
}
30-
scf.yield %con1 : vector<4x64xf32>
35+
scf.yield %3 : vector<4x64xf32>
3136
}
32-
vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
37+
vector.transfer_write %2, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
3338
}
3439
}
35-
return
3640
}
41+
return %alloc : memref<8x24x32x64xf32>
42+
}
43+
44+
// CHECK-LABEL: memref.global "private" constant @__constant_24x64x64xf32 : memref<24x64x64xf32> = dense<1.000000e+00> {alignment = 64 : i64}
45+
// CHECK-LABEL: func.func @lower_contract_to_fma(
46+
// CHECK-SAME: %[[VAL_0:.*]]: memref<8x24x32x64xf32>) -> memref<8x24x32x64xf32> {
47+
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
48+
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
49+
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0.000000e+00> : vector<32x64xf32>
50+
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
51+
// CHECK: %[[VAL_5:.*]] = arith.constant 24 : index
52+
// CHECK: %[[VAL_6:.*]] = arith.constant 64 : index
53+
// CHECK: %[[VAL_7:.*]] = arith.constant 4 : index
54+
// CHECK: %[[VAL_8:.*]] = arith.constant 32 : index
55+
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
56+
// CHECK: %[[VAL_10:.*]] = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
57+
// CHECK: %[[VAL_11:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
58+
// CHECK: scf.forall (%[[VAL_12:.*]], %[[VAL_13:.*]]) in (8, 24) {
59+
// CHECK: %[[VAL_14:.*]] = memref.subview %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], 0, 0] [1, 1, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<32x64xf32, strided<[64, 1], offset: ?>>
60+
// CHECK: vector.transfer_write %[[VAL_3]], %[[VAL_14]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>
61+
// CHECK: %[[VAL_15:.*]] = memref.subview %[[VAL_0]]{{\[}}%[[VAL_12]], 0, 0, 0] [1, 24, 32, 64] [1, 1, 1, 1] : memref<8x24x32x64xf32> to memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>>
62+
// CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_9]] to %[[VAL_8]] step %[[VAL_7]] {
63+
// CHECK: scf.for %[[VAL_17:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_6]] {
64+
// CHECK: %[[VAL_18:.*]] = memref.subview %[[VAL_14]]{{\[}}%[[VAL_16]], %[[VAL_17]]] [4, 64] [1, 1] : memref<32x64xf32, strided<[64, 1], offset: ?>> to memref<4x64xf32, strided<[64, 1], offset: ?>>
65+
// CHECK: %[[VAL_19:.*]] = memref.subview %[[VAL_18]][0, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
66+
// CHECK: %[[VAL_20:.*]] = memref.subview %[[VAL_18]][1, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
67+
// CHECK: %[[VAL_21:.*]] = memref.subview %[[VAL_18]][2, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
68+
// CHECK: %[[VAL_22:.*]] = memref.subview %[[VAL_18]][3, 0] [1, 64] [1, 1] : memref<4x64xf32, strided<[64, 1], offset: ?>> to memref<1x64xf32, strided<[64, 1], offset: ?>>
69+
// CHECK: %[[VAL_23:.*]] = vector.load %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
70+
// CHECK: %[[VAL_24:.*]] = vector.load %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
71+
// CHECK: %[[VAL_25:.*]] = vector.load %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
72+
// CHECK: %[[VAL_26:.*]] = vector.load %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
73+
// CHECK: %[[VAL_27:.*]]:4 = scf.for %[[VAL_28:.*]] = %[[VAL_9]] to %[[VAL_5]] step %[[VAL_4]] iter_args(%[[VAL_29:.*]] = %[[VAL_23]], %[[VAL_30:.*]] = %[[VAL_24]], %[[VAL_31:.*]] = %[[VAL_25]], %[[VAL_32:.*]] = %[[VAL_26]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
74+
// CHECK: %[[VAL_33:.*]]:4 = scf.for %[[VAL_34:.*]] = %[[VAL_9]] to %[[VAL_6]] step %[[VAL_4]] iter_args(%[[VAL_35:.*]] = %[[VAL_29]], %[[VAL_36:.*]] = %[[VAL_30]], %[[VAL_37:.*]] = %[[VAL_31]], %[[VAL_38:.*]] = %[[VAL_32]]) -> (vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>) {
75+
// CHECK: %[[VAL_39:.*]] = memref.subview %[[VAL_15]]{{\[}}%[[VAL_28]], %[[VAL_16]], %[[VAL_34]]] [1, 4, 1] [1, 1, 1] : memref<24x32x64xf32, strided<[2048, 64, 1], offset: ?>> to memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
76+
// CHECK: %[[VAL_40:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
77+
// CHECK: %[[VAL_41:.*]] = vector.broadcast %[[VAL_40]] : f32 to vector<64xf32>
78+
// CHECK: %[[VAL_42:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_4]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
79+
// CHECK: %[[VAL_43:.*]] = vector.broadcast %[[VAL_42]] : f32 to vector<64xf32>
80+
// CHECK: %[[VAL_44:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_2]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
81+
// CHECK: %[[VAL_45:.*]] = vector.broadcast %[[VAL_44]] : f32 to vector<64xf32>
82+
// CHECK: %[[VAL_46:.*]] = memref.load %[[VAL_39]]{{\[}}%[[VAL_9]], %[[VAL_1]], %[[VAL_9]]] : memref<1x4x1xf32, strided<[2048, 64, 1], offset: ?>>
83+
// CHECK: %[[VAL_47:.*]] = vector.broadcast %[[VAL_46]] : f32 to vector<64xf32>
84+
// CHECK: %[[VAL_48:.*]] = memref.subview %[[VAL_10]]{{\[}}%[[VAL_28]], %[[VAL_34]], %[[VAL_17]]] [1, 1, 64] [1, 1, 1] : memref<24x64x64xf32> to memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>
85+
// CHECK: %[[VAL_49:.*]] = vector.load %[[VAL_48]]{{\[}}%[[VAL_9]], %[[VAL_9]], %[[VAL_9]]] : memref<1x1x64xf32, strided<[4096, 64, 1], offset: ?>>, vector<64xf32>
86+
// CHECK: %[[VAL_50:.*]] = vector.fma %[[VAL_41]], %[[VAL_49]], %[[VAL_35]] : vector<64xf32>
87+
// CHECK: %[[VAL_51:.*]] = vector.fma %[[VAL_43]], %[[VAL_49]], %[[VAL_36]] : vector<64xf32>
88+
// CHECK: %[[VAL_52:.*]] = vector.fma %[[VAL_45]], %[[VAL_49]], %[[VAL_37]] : vector<64xf32>
89+
// CHECK: %[[VAL_53:.*]] = vector.fma %[[VAL_47]], %[[VAL_49]], %[[VAL_38]] : vector<64xf32>
90+
// CHECK: scf.yield %[[VAL_50]], %[[VAL_51]], %[[VAL_52]], %[[VAL_53]] : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
91+
// CHECK: }
92+
// CHECK: scf.yield %[[VAL_54:.*]]#0, %[[VAL_54]]#1, %[[VAL_54]]#2, %[[VAL_54]]#3 : vector<64xf32>, vector<64xf32>, vector<64xf32>, vector<64xf32>
93+
// CHECK: }
94+
// CHECK: vector.store %[[VAL_55:.*]]#0, %[[VAL_19]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
95+
// CHECK: vector.store %[[VAL_55]]#1, %[[VAL_20]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
96+
// CHECK: vector.store %[[VAL_55]]#2, %[[VAL_21]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
97+
// CHECK: vector.store %[[VAL_55]]#3, %[[VAL_22]]{{\[}}%[[VAL_9]], %[[VAL_9]]] : memref<1x64xf32, strided<[64, 1], offset: ?>>, vector<64xf32>
98+
// CHECK: }
99+
// CHECK: }
100+
// CHECK: }
101+
// CHECK: return %[[VAL_11]] : memref<8x24x32x64xf32>
102+
// CHECK: }
103+
104+
module attributes {transform.with_named_sequence} {
105+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
106+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
107+
transform.apply_patterns to %0 {
108+
transform.apply_patterns.vector.contract_to_fma
109+
} : !transform.any_op
110+
transform.yield
111+
}
112+
}
113+
114+
//-----
115+
116+
#mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
117+
#mapB = affine_map<(d0, d1, d2) -> (d2, d1)>
118+
#mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
119+
func.func @matmul_without_iterarg_accumulator_so_no_lowering_to_fma(%arg0: tensor<4x1xf32>, %arg1: tensor<1x64xf32>, %arg2: tensor<4x64xf32>) -> tensor<4x64xf32> {
120+
%c0 = arith.constant 0 : index
121+
%cst = arith.constant 0.000000e+00 : f32
122+
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x1xf32>, vector<4x1xf32>
123+
%1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x64xf32>, vector<1x64xf32>
124+
%2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x64xf32>, vector<4x64xf32>
125+
%3 = vector.contract {indexing_maps = [#mapA, #mapB, #mapC], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x1xf32>, vector<1x64xf32> into vector<4x64xf32>
126+
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, tensor<4x64xf32>
127+
return %4 : tensor<4x64xf32>
128+
}
37129

38130
// CHECK-NOT: vector.fma
39131

40-
module attributes {transform.with_named_sequence} {
41-
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
42-
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
43-
transform.apply_patterns to %0 {
44-
transform.apply_patterns.vector.contract_to_fma
45-
} : !transform.any_op
46-
transform.yield
132+
module attributes {transform.with_named_sequence} {
133+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
134+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
135+
transform.apply_patterns to %0 {
136+
transform.apply_patterns.vector.contract_to_fma
137+
} : !transform.any_op
138+
transform.yield
139+
}
140+
}
141+
142+
143+
// -----
144+
145+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
146+
#mapTransposeB = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
147+
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
148+
func.func @transpose_matrix_no_lowering_to_fma(%arg0: memref<16x32x128xf32>, %arg1: memref<16x128x64xf32>, %arg2: memref<32x64xf32>) {
149+
%cst = arith.constant 0.000000e+00 : f32
150+
%c1 = arith.constant 1 : index
151+
%c16 = arith.constant 16 : index
152+
%c64 = arith.constant 64 : index
153+
%c128 = arith.constant 128 : index
154+
%c4 = arith.constant 4 : index
155+
%c32 = arith.constant 32 : index
156+
%c0 = arith.constant 0 : index
157+
158+
scf.for %arg5 = %c0 to %c32 step %c4 {
159+
scf.for %arg6 = %c0 to %c128 step %c64 {
160+
%subview_2 = memref.subview %arg2[%arg5, %arg6] [4, 64] [1, 1] : memref<32x64xf32> to memref<4x64xf32, strided<[64, 1], offset: ?>>
161+
%2 = vector.transfer_read %subview_2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<4x64xf32, strided<[64, 1], offset: ?>>, vector<4x64xf32>
162+
%con = scf.for %arg7 = %c0 to %c16 step %c1 iter_args(%argcon = %2) -> vector<4x64xf32> {
163+
%con1 = scf.for %arg8 = %c0 to %c64 step %c1 iter_args(%argcon1 = %argcon) -> vector<4x64xf32> {
164+
%subview_3 = memref.subview %arg0[%arg7, %arg5, %arg8] [1, 4, 1] [1, 1, 1] : memref<16x32x128xf32> to memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>
165+
%subview_4 = memref.subview %arg1[%arg7, %arg8, %arg6] [1, 1, 64] [1, 1, 1] : memref<16x128x64xf32> to memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>
166+
%0 = vector.transfer_read %subview_3[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x4x1xf32, strided<[4096, 128, 1], offset: ?>>, vector<1x4x1xf32>
167+
%1 = vector.transfer_read %subview_4[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>, in_bounds = [true, true, true]} : memref<1x1x64xf32, strided<[8192, 64, 1], offset: ?>>, vector<1x64x1xf32>
168+
%3 = vector.contract {indexing_maps = [#map, #mapTransposeB, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %argcon1 : vector<1x4x1xf32>, vector<1x64x1xf32> into vector<4x64xf32>
169+
scf.yield %3 : vector<4x64xf32>
170+
}
171+
scf.yield %con1 : vector<4x64xf32>
172+
}
173+
vector.transfer_write %con, %subview_2[%c0, %c0] {in_bounds = [true, true]} : vector<4x64xf32>, memref<4x64xf32, strided<[64, 1], offset: ?>>
47174
}
48175
}
176+
return
177+
}
178+
179+
// CHECK-NOT: vector.fma
180+
181+
module attributes {transform.with_named_sequence} {
182+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
183+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
184+
transform.apply_patterns to %0 {
185+
transform.apply_patterns.vector.contract_to_fma
186+
} : !transform.any_op
187+
transform.yield
188+
}
189+
}

0 commit comments

Comments
 (0)