@@ -232,7 +232,6 @@ builtin.module attributes { transform.with_named_sequence } {
232232}
233233
234234// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch
235- // CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00>
236235// CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x32xf32> -> vector<2x1x4x1x4x1xf32>
237236// CHECK: %[[A_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
238237// CHECK: %[[C_SLICE0:.+]] = vector.extract %[[C_SIMT]][0, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
@@ -241,15 +240,18 @@ builtin.module attributes { transform.with_named_sequence } {
241240// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x4x1xf32> to vector<16xf32>
242241// CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
243242// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x4x1xf32>
244- // CHECK: %[[C0_INS:.+]] = vector.insert %[[R0_CAST]], %[[INIT]] [0, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
245243// CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
246244// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
247245// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
248246// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x4x1xf32> to vector<16xf32>
249247// CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
250248// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x4x1xf32>
251- // CHECK: %[[C1_INS:.+]] = vector.insert %[[R1_CAST]], %[[C0_INS]] [1, 0] : vector<4x1x4x1xf32> into vector<2x1x4x1x4x1xf32>
252- // CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[C1_INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
249+ // CHECK: %[[R0:.+]]:16 = vector.to_elements %[[R0_CAST]] : vector<4x1x4x1xf32>
250+ // CHECK: %[[R1:.+]]:16 = vector.to_elements %[[R1_CAST]] : vector<4x1x4x1xf32>
251+ // CHECK: %[[INS:.+]] = vector.from_elements
252+ // CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3, %[[R0]]#4, %[[R0]]#5, %[[R0]]#6, %[[R0]]#7, %[[R0]]#8, %[[R0]]#9, %[[R0]]#10, %[[R0]]#11, %[[R0]]#12, %[[R0]]#13, %[[R0]]#14, %[[R0]]#15
253+ // CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3, %[[R1]]#4, %[[R1]]#5, %[[R1]]#6, %[[R1]]#7, %[[R1]]#8, %[[R1]]#9, %[[R1]]#10, %[[R1]]#11, %[[R1]]#12, %[[R1]]#13, %[[R1]]#14, %[[R1]]#15
254+ // CHECK: %[[R:.+]] = iree_vector_ext.to_simd %[[INS]] : vector<2x1x4x1x4x1xf32> -> vector<64x32xf32>
253255// CHECK: return %[[R]]
254256
255257// -----
@@ -403,28 +405,23 @@ builtin.module attributes { transform.with_named_sequence } {
403405 }
404406}
405407
406- // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
407- // CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<2x3x4x1x4x1xf32>
408- // CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
409- // CHECK: vector.extract %[[C_SIMT]][0, 0]
410- // CHECK: amdgpu.mfma
411- // CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
412- // CHECK: vector.extract %[[C_SIMT]][0, 1]
413- // CHECK: amdgpu.mfma
414- // CHECK: %[[INS1:.+]] = vector.insert %{{.+}}, %[[INS0]] [0, 1]
415- // CHECK: vector.extract %[[C_SIMT]][0, 2]
416- // CHECK: amdgpu.mfma
417- // CHECK: %[[INS2:.+]] = vector.insert %{{.+}}, %[[INS1]] [0, 2]
418- // CHECK: vector.extract %[[C_SIMT]][1, 0]
419- // CHECK: amdgpu.mfma
420- // CHECK: %[[INS3:.+]] = vector.insert %{{.+}}, %[[INS2]] [1, 0]
421- // CHECK: vector.extract %[[C_SIMT]][1, 1]
422- // CHECK: amdgpu.mfma
423- // CHECK: %[[INS4:.+]] = vector.insert %{{.+}}, %[[INS3]] [1, 1]
424- // CHECK: vector.extract %[[C_SIMT]][1, 2]
425- // CHECK: amdgpu.mfma
426- // CHECK: %[[INS5:.+]] = vector.insert %{{.+}}, %[[INS4]] [1, 2]
427- // CHECK: iree_vector_ext.to_simd %[[INS5]]
408+ // CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mm_mnbatch_order
409+ // CHECK: %[[C_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x96xf32> -> vector<2x3x4x1x4x1xf32>
410+ // CHECK: vector.extract %[[C_SIMT]][0, 0]
411+ // CHECK: amdgpu.mfma
412+ // CHECK: vector.extract %[[C_SIMT]][0, 1]
413+ // CHECK: amdgpu.mfma
414+ // CHECK: vector.extract %[[C_SIMT]][0, 2]
415+ // CHECK: amdgpu.mfma
416+ // CHECK: vector.extract %[[C_SIMT]][1, 0]
417+ // CHECK: amdgpu.mfma
418+ // CHECK: vector.extract %[[C_SIMT]][1, 1]
419+ // CHECK: amdgpu.mfma
420+ // CHECK: vector.extract %[[C_SIMT]][1, 2]
421+ // CHECK: amdgpu.mfma
422+ // CHECK-COUNT-6: vector.to_elements {{.*}} : vector<4x1x4x1xf32>
423+ // CHECK: %[[INS:.+]] = vector.from_elements
424+ // CHECK: iree_vector_ext.to_simd %[[INS]]
428425
429426// -----
430427
@@ -495,15 +492,17 @@ builtin.module attributes { transform.with_named_sequence } {
495492}
496493
497494// CHECK-LABEL: func.func @contract_to_mfma_32x32x8_mmt
498- // CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x2x4x1x4x1xf32>
499495// CHECK: %[[B_SIMT:.+]] = iree_vector_ext.to_simt %{{.+}} : vector<64x8xf16> -> vector<2x1x1x1x1x4xf16>
500496// CHECK: vector.extract %[[B_SIMT]][0, 0]
501497// CHECK: amdgpu.mfma
502- // CHECK: %[[INS0:.+]] = vector.insert %{{.+}}, %[[INIT]] [0, 0]
503498// CHECK: vector.extract %[[B_SIMT]][1, 0]
504499// CHECK: amdgpu.mfma
505- // CHECK: %[[INS1:.+]] = vector.insert %17, %[[INS0]] [0, 1]
506- // CHECK: iree_vector_ext.to_simd %[[INS1]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>
500+ // CHECK: %[[R0:.+]]:16 = vector.to_elements %{{.+}} : vector<4x1x4x1xf32>
501+ // CHECK: %[[R1:.+]]:16 = vector.to_elements %{{.+}} : vector<4x1x4x1xf32>
502+ // CHECK: %[[INS:.+]] = vector.from_elements
503+ // CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3, %[[R0]]#4, %[[R0]]#5, %[[R0]]#6, %[[R0]]#7, %[[R0]]#8, %[[R0]]#9, %[[R0]]#10, %[[R0]]#11, %[[R0]]#12, %[[R0]]#13, %[[R0]]#14, %[[R0]]#15
504+ // CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3, %[[R1]]#4, %[[R1]]#5, %[[R1]]#6, %[[R1]]#7, %[[R1]]#8, %[[R1]]#9, %[[R1]]#10, %[[R1]]#11, %[[R1]]#12, %[[R1]]#13, %[[R1]]#14, %[[R1]]#15
505+ // CHECK: iree_vector_ext.to_simd %[[INS]] : vector<1x2x4x1x4x1xf32> -> vector<32x64xf32>
507506
508507// -----
509508
@@ -838,6 +837,7 @@ builtin.module attributes { transform.with_named_sequence } {
838837// CHECK: %[[B_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
839838// CHECK: %[[MFMA_1:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_1]] + %[[MFMA_0]]
840839// CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
840+ // CHECK: %[[MFMA_1_CAST:.*]] = vector.shape_cast %[[MFMA_1]] : vector<4xf32> to vector<1x1x4x1xf32>
841841// CHECK: %[[B_CAST_2:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
842842// CHECK: %[[C_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x4x1xf32> to vector<4xf32>
843843// CHECK: %[[MFMA_2:.*]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST_2]] + %[[C_CAST_1]]
@@ -846,6 +846,10 @@ builtin.module attributes { transform.with_named_sequence } {
846846// CHECK: %[[MFMA_3:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_3]] + %[[MFMA_2]]
847847// CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
848848// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_3]] : vector<4xf32> to vector<1x1x4x1xf32>
849- // CHECK: %[[B_OUT:.*]] = vector.insert %[[R_CAST]]
849+ // CHECK: %[[R0:.+]]:4 = vector.to_elements %[[MFMA_1_CAST]] : vector<1x1x4x1xf32>
850+ // CHECK: %[[R1:.+]]:4 = vector.to_elements %[[R_CAST]] : vector<1x1x4x1xf32>
851+ // CHECK: %[[B_OUT:.+]] = vector.from_elements
852+ // CHECK-SAME: %[[R0]]#0, %[[R0]]#1, %[[R0]]#2, %[[R0]]#3
853+ // CHECK-SAME: %[[R1]]#0, %[[R1]]#1, %[[R1]]#2, %[[R1]]#3
850854// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x2x1x1x4x1xf32> -> vector<32x32xf32>
851855// CHECK: return %[[R_SIMD]]
0 commit comments