@@ -77,8 +77,7 @@ builtin.module attributes { transform.with_named_sequence } {
7777// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
7878// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x4x1xf16> to vector<4xf16>
7979// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<4x1x4x1xf32> to vector<16xf32>
80- // CHECK: %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
81- // CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
80+ // CHECK: %[[MFMA:.+]] = amdgpu.mfma 32x32x8 %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]] blgp = none
8281// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
8382// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<16xf32> to vector<4x1x4x1xf32>
8483// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
@@ -154,8 +153,7 @@ builtin.module attributes { transform.with_named_sequence } {
154153// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x4xf16> to vector<4xf16>
155154// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x4x1xf16> to vector<4xf16>
156155// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<1x1x4x1xf32> to vector<4xf32>
157- // CHECK: %[[MFMA:.+]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
158- // CHECK-SAME: {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none
156+ // CHECK: %[[MFMA:.+]] = amdgpu.mfma 16x16x16 %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]] blgp = none
159157// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<4xf32>
160158
161159// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA]] : vector<4xf32> to vector<1x1x4x1xf32>
@@ -238,13 +236,13 @@ builtin.module attributes { transform.with_named_sequence } {
238236// CHECK: %[[A_SLICE0:.+]] = vector.extract %[[A_SIMT]][0, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
239237// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]] : vector<1x1x1x4xf16> to vector<4xf16>
240238// CHECK: %[[C0_CAST:.+]] = vector.shape_cast %[[C_SLICE0]] : vector<4x1x4x1xf32> to vector<16xf32>
241- // CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
239+ // CHECK: %[[MFMA0:.+]] = amdgpu.mfma 32x32x8 %[[A0_CAST]] * %{{.+}} + %[[C0_CAST]]
242240// CHECK: %[[R0_CAST:.+]] = vector.shape_cast %[[MFMA0]] : vector<16xf32> to vector<4x1x4x1xf32>
243241// CHECK: %[[C_SLICE1:.+]] = vector.extract %[[C_SIMT]][1, 0] : vector<4x1x4x1xf32> from vector<2x1x4x1x4x1xf32>
244242// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][1, 0] : vector<1x1x1x4xf16> from vector<2x1x1x1x1x4xf16>
245243// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]] : vector<1x1x1x4xf16> to vector<4xf16>
246244// CHECK: %[[C1_CAST:.+]] = vector.shape_cast %[[C_SLICE1]] : vector<4x1x4x1xf32> to vector<16xf32>
247- // CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
245+ // CHECK: %[[MFMA1:.+]] = amdgpu.mfma 32x32x8 %[[A1_CAST]] * %{{.+}} + %[[C1_CAST]]
248246// CHECK: %[[R1_CAST:.+]] = vector.shape_cast %[[MFMA1]] : vector<16xf32> to vector<4x1x4x1xf32>
249247// CHECK: %[[R0:.+]]:16 = vector.to_elements %[[R0_CAST]] : vector<4x1x4x1xf32>
250248// CHECK: %[[R1:.+]]:16 = vector.to_elements %[[R1_CAST]] : vector<4x1x4x1xf32>
@@ -329,12 +327,12 @@ builtin.module attributes { transform.with_named_sequence } {
329327// CHECK: %[[B_SLICE0:.+]] = vector.extract %[[B_SIMT]][0, 0]
330328// CHECK: %[[A0_CAST:.+]] = vector.shape_cast %[[A_SLICE0]]
331329// CHECK: %[[B0_CAST:.+]] = vector.shape_cast %[[B_SLICE0]]
332- // CHECK: %[[MFMA0:.+]] = amdgpu.mfma %[[A0_CAST]] * %[[B0_CAST]] + %{{.+}}
330+ // CHECK: %[[MFMA0:.+]] = amdgpu.mfma 32x32x8 %[[A0_CAST]] * %[[B0_CAST]] + %{{.+}}
333331// CHECK: %[[A_SLICE1:.+]] = vector.extract %[[A_SIMT]][0, 1]
334332// CHECK: %[[B_SLICE1:.+]] = vector.extract %[[B_SIMT]][1, 0]
335333// CHECK: %[[A1_CAST:.+]] = vector.shape_cast %[[A_SLICE1]]
336334// CHECK: %[[B1_CAST:.+]] = vector.shape_cast %[[B_SLICE1]]
337- // CHECK: %[[MFMA1:.+]] = amdgpu.mfma %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
335+ // CHECK: %[[MFMA1:.+]] = amdgpu.mfma 32x32x8 %[[A1_CAST]] * %[[B1_CAST]] + %[[MFMA0]]
338336// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA1]]
339337
340338// -----
@@ -584,7 +582,7 @@ builtin.module attributes { transform.with_named_sequence } {
584582// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x16xf16> to vector<16xf1
585583// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x16x1xf16> to vector<16xf1
586584// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<8x1x1x1xf32> to vector<8xf32>
587- // CHECK: %[[WMMA:.+]] = amdgpu.wmma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
585+ // CHECK: %[[WMMA:.+]] = amdgpu.wmma 16x16x16 %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
588586// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[WMMA]] : vector<8xf32> to vector<8x1x1x1xf32>
589587// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<8x1x1x1xf32> to vector<1x1x8x1x1x1xf32>
590588// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x8x1x1x1xf32> -> vector<16x16xf32>
@@ -670,7 +668,7 @@ builtin.module attributes { transform.with_named_sequence } {
670668// CHECK: %[[A_CAST:.+]] = vector.shape_cast %[[A_VEC]] : vector<1x1x1x8xf16> to vector<8xf16>
671669// CHECK: %[[B_CAST:.+]] = vector.shape_cast %[[B_VEC]] : vector<1x1x8x1xf16> to vector<8xf16>
672670// CHECK: %[[C_CAST:.+]] = vector.shape_cast %[[C_VEC]] : vector<1x1x8x1xf32> to vector<8xf32>
673- // CHECK: %[[WMMA:.+]] = amdgpu.wmma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
671+ // CHECK: %[[WMMA:.+]] = amdgpu.wmma 16x16x16 %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
674672// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[WMMA]] : vector<8xf32> to vector<1x1x8x1xf32>
675673// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<1x1x8x1xf32> to vector<1x1x1x1x8x1xf32>
676674// CHECK: %[[R_SIMD:.+]] = iree_vector_ext.to_simd %[[B_OUT]] : vector<1x1x1x1x8x1xf32> -> vector<16x16xf32>
@@ -756,13 +754,11 @@ builtin.module attributes { transform.with_named_sequence } {
756754// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<4x1x4x1xf32> to vector<16xf32>
757755// CHECK: %[[A_SLICE_0:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
758756// CHECK: %[[B_SLICE_0:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
759- // CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]]
760- // CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
757+ // CHECK: %[[MFMA_0:.*]] = amdgpu.mfma 32x32x8 %[[A_SLICE_0]] * %[[B_SLICE_0]] + %[[C_CAST]] blgp = none
761758// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
762759// CHECK: %[[A_SLICE_1:.+]] = vector.extract_strided_slice %[[A_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
763760// CHECK: %[[B_SLICE_1:.+]] = vector.extract_strided_slice %[[B_CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
764- // CHECK: %[[MFMA_1:.+]] = amdgpu.mfma %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]]
765- // CHECK-SAME: {blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32} blgp = none
761+ // CHECK: %[[MFMA_1:.+]] = amdgpu.mfma 32x32x8 %[[A_SLICE_1]] * %[[B_SLICE_1]] + %[[MFMA_0]] blgp = none
766762// CHECK-SAME: : vector<4xf16>, vector<4xf16>, vector<16xf32>
767763// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_1]] : vector<16xf32> to vector<4x1x4x1xf32>
768764// CHECK: %[[B_OUT:.*]] = vector.broadcast %[[R_CAST]] : vector<4x1x4x1xf32> to vector<1x1x4x1x4x1xf32>
@@ -831,20 +827,16 @@ builtin.module attributes { transform.with_named_sequence } {
831827// CHECK: %[[A_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
832828// CHECK: %[[B_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
833829// CHECK: %[[C_CAST:.+]] = vector.shape_cast %{{.+}} : vector<1x1x4x1xf32> to vector<4xf32>
834- // CHECK: %[[MFMA_0:.*]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]]
835- // CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
830+ // CHECK: %[[MFMA_0:.*]] = amdgpu.mfma 16x16x32 %[[A_CAST]] * %[[B_CAST]] + %[[C_CAST]] blgp = none
836831// CHECK: %[[A_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
837832// CHECK: %[[B_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
838- // CHECK: %[[MFMA_1:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_1]] + %[[MFMA_0]]
839- // CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
833+ // CHECK: %[[MFMA_1:.*]] = amdgpu.mfma 16x16x32 %[[A_CAST_1]] * %[[B_CAST_1]] + %[[MFMA_0]] blgp = none
840834// CHECK: %[[MFMA_1_CAST:.*]] = vector.shape_cast %[[MFMA_1]] : vector<4xf32> to vector<1x1x4x1xf32>
841835// CHECK: %[[B_CAST_2:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
842836// CHECK: %[[C_CAST_1:.+]] = vector.shape_cast %{{.+}} : vector<1x1x4x1xf32> to vector<4xf32>
843- // CHECK: %[[MFMA_2:.*]] = amdgpu.mfma %[[A_CAST]] * %[[B_CAST_2]] + %[[C_CAST_1]]
844- // CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
837+ // CHECK: %[[MFMA_2:.*]] = amdgpu.mfma 16x16x32 %[[A_CAST]] * %[[B_CAST_2]] + %[[C_CAST_1]] blgp = none
845838// CHECK: %[[B_CAST_3:.+]] = vector.shape_cast %{{.+}} : vector<1x1x1x1x1x8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ>
846- // CHECK: %[[MFMA_3:.*]] = amdgpu.mfma %[[A_CAST_1]] * %[[B_CAST_3]] + %[[MFMA_2]]
847- // CHECK-SAME: {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none
839+ // CHECK: %[[MFMA_3:.*]] = amdgpu.mfma 16x16x32 %[[A_CAST_1]] * %[[B_CAST_3]] + %[[MFMA_2]] blgp = none
848840// CHECK: %[[R_CAST:.+]] = vector.shape_cast %[[MFMA_3]] : vector<4xf32> to vector<1x1x4x1xf32>
849841// CHECK: %[[R0:.+]]:4 = vector.to_elements %[[MFMA_1_CAST]] : vector<1x1x4x1xf32>
850842// CHECK: %[[R1:.+]]:4 = vector.to_elements %[[R_CAST]] : vector<1x1x4x1xf32>
0 commit comments