@@ -43,8 +43,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
4343 %acc = wave.register %cst_f32 : vector <4 xf32 >
4444
4545 // CHECK-NOT: wave.mma
46- // CHECK: amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]]
47- // CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
46+ // CHECK: amdgpu.mfma 16x16x16 %[[LHS]] * %[[RHS]] + %[[ACC]]
4847 // CHECK-SAME: blgp = none
4948 // CHECK-SAME: vector<4xf16>, vector<4xf16>, vector<4xf32>
5049 %res = wave.mma %lhs , %rhs , %acc {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >}
@@ -93,88 +92,88 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
9392 // f16 kinds
9493 // CHECK-NOT: wave.mma
9594 // CHECK: amdgpu.mfma
96- // CHECK-SAME: k = 16 : i32, m = 16 : i32, n = 16 : i32
95+ // CHECK-SAME: 16x16x16
9796 %0 = wave.mma %lhs_f16 , %rhs_f16 , %acc_f32_4 {kind = #wave.mma_kind <f32 _16 x16 x16 _f16 >}
9897 : (vector <4 xf16 >, vector <4 xf16 >, vector <4 xf32 >) -> vector <4 xf32 >
9998 // CHECK-NOT: wave.mma
10099 // CHECK: amdgpu.mfma
101- // CHECK-SAME: k = 8 : i32, m = 32 : i32, n = 32 : i32
100+ // CHECK-SAME: 32x32x8
102101 %1 = wave.mma %lhs_f16 , %rhs_f16 , %acc_f32_16 {kind = #wave.mma_kind <f32 _32 x32 x8 _f16 >}
103102 : (vector <4 xf16 >, vector <4 xf16 >, vector <16 xf32 >) -> vector <16 xf32 >
104103 // CHECK-NOT: wave.mma
105104 // CHECK: amdgpu.mfma
106- // CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
105+ // CHECK-SAME: 16x16x32
107106 %2 = wave.mma %lhs_f16_w8 , %rhs_f16_w8 , %acc_f32_4 {kind = #wave.mma_kind <f32 _16 x16 x32 _k8 _f16 >}
108107 : (vector <8 xf16 >, vector <8 xf16 >, vector <4 xf32 >) -> vector <4 xf32 >
109108 // CHECK-NOT: wave.mma
110109 // CHECK: amdgpu.mfma
111- // CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
110+ // CHECK-SAME: 32x32x16
112111 %3 = wave.mma %lhs_f16_w8 , %rhs_f16_w8 , %acc_f32_16 {kind = #wave.mma_kind <f32 _32 x32 x16 _k8 _f16 >}
113112 : (vector <8 xf16 >, vector <8 xf16 >, vector <16 xf32 >) -> vector <16 xf32 >
114113 // CHECK-NOT: wave.mma
115114 // CHECK: amdgpu.mfma
116- // CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
115+ // CHECK-SAME: 32x32x16
117116 %4 = wave.mma %lhs_f16_w8 , %rhs_f16_w8 , %acc_f32_16 {kind = #wave.mma_kind <f32 _32 x32 x16 _f16 >}
118117 : (vector <8 xf16 >, vector <8 xf16 >, vector <16 xf32 >) -> vector <16 xf32 >
119118 // CHECK-NOT: wave.mma
120119 // CHECK: amdgpu.mfma
121- // CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
120+ // CHECK-SAME: 16x16x32
122121 %5 = wave.mma %lhs_f16_w8 , %rhs_f16_w8 , %acc_f32_4 {kind = #wave.mma_kind <f32 _16 x16 x32 _f16 >}
123122 : (vector <8 xf16 >, vector <8 xf16 >, vector <4 xf32 >) -> vector <4 xf32 >
124123
125124 // bf16 kinds
126125 // CHECK-NOT: wave.mma
127126 // CHECK: amdgpu.mfma
128- // CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
127+ // CHECK-SAME: 32x32x16
129128 %6 = wave.mma %lhs_bf16 , %rhs_bf16 , %acc_f32_16 {kind = #wave.mma_kind <f32 _32 x32 x16 _bf16 >}
130129 : (vector <8 xbf16 >, vector <8 xbf16 >, vector <16 xf32 >) -> vector <16 xf32 >
131130 // CHECK-NOT: wave.mma
132131 // CHECK: amdgpu.mfma
133- // CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
132+ // CHECK-SAME: 16x16x32
134133 %7 = wave.mma %lhs_bf16 , %rhs_bf16 , %acc_f32_4 {kind = #wave.mma_kind <f32 _16 x16 x32 _bf16 >}
135134 : (vector <8 xbf16 >, vector <8 xbf16 >, vector <4 xf32 >) -> vector <4 xf32 >
136135
137136 // f8 kinds
138137 // CHECK-NOT: wave.mma
139138 // CHECK: amdgpu.mfma
140- // CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
139+ // CHECK-SAME: 16x16x32
141140 %8 = wave.mma %lhs_f8 , %rhs_f8 , %acc_f32_4 {kind = #wave.mma_kind <f32 _16 x16 x32 _f8 >}
142141 : (vector <8 xf8 E4 M3 FNUZ>, vector <8 xf8 E4 M3 FNUZ>, vector <4 xf32 >) -> vector <4 xf32 >
143142 // CHECK-NOT: wave.mma
144143 // CHECK: amdgpu.mfma
145- // CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
144+ // CHECK-SAME: 32x32x16
146145 %9 = wave.mma %lhs_f8 , %rhs_f8 , %acc_f32_16 {kind = #wave.mma_kind <f32 _32 x32 x16 _f8 >}
147146 : (vector <8 xf8 E4 M3 FNUZ>, vector <8 xf8 E4 M3 FNUZ>, vector <16 xf32 >) -> vector <16 xf32 >
148147 // CHECK-NOT: wave.mma
149148 // CHECK: amdgpu.mfma
150- // CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
149+ // CHECK-SAME: 16x16x32
151150 %10 = wave.mma %lhs_f8 , %rhs_f8 , %acc_f32_4 {kind = #wave.mma_kind <f32 _16 x16 x32 _k4 _f8 >}
152151 : (vector <8 xf8 E4 M3 FNUZ>, vector <8 xf8 E4 M3 FNUZ>, vector <4 xf32 >) -> vector <4 xf32 >
153152 // CHECK-NOT: wave.mma
154153 // CHECK: amdgpu.mfma
155- // CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
154+ // CHECK-SAME: 32x32x16
156155 %11 = wave.mma %lhs_f8 , %rhs_f8 , %acc_f32_16 {kind = #wave.mma_kind <f32 _32 x32 x16 _k4 _f8 >}
157156 : (vector <8 xf8 E4 M3 FNUZ>, vector <8 xf8 E4 M3 FNUZ>, vector <16 xf32 >) -> vector <16 xf32 >
158157
159158 // i8 kinds
160159 // CHECK-NOT: wave.mma
161160 // CHECK: amdgpu.mfma
162- // CHECK-SAME: k = 16 : i32, m = 16 : i32, n = 16 : i32
161+ // CHECK-SAME: 16x16x16
163162 %12 = wave.mma %lhs_i8 , %rhs_i8 , %acc_i32_4 {kind = #wave.mma_kind <i32 _16 x16 x16 _i8 >}
164163 : (vector <4 xi8 >, vector <4 xi8 >, vector <4 xi32 >) -> vector <4 xi32 >
165164 // CHECK-NOT: wave.mma
166165 // CHECK: amdgpu.mfma
167- // CHECK-SAME: k = 8 : i32, m = 32 : i32, n = 32 : i32
166+ // CHECK-SAME: 32x32x8
168167 %13 = wave.mma %lhs_i8 , %rhs_i8 , %acc_i32_16 {kind = #wave.mma_kind <i32 _32 x32 x8 _i8 >}
169168 : (vector <4 xi8 >, vector <4 xi8 >, vector <16 xi32 >) -> vector <16 xi32 >
170169 // CHECK-NOT: wave.mma
171170 // CHECK: amdgpu.mfma
172- // CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
171+ // CHECK-SAME: 16x16x32
173172 %14 = wave.mma %lhs_i8_w8 , %rhs_i8_w8 , %acc_i32_4 {kind = #wave.mma_kind <i32 _16 x16 x32 _i8 >}
174173 : (vector <8 xi8 >, vector <8 xi8 >, vector <4 xi32 >) -> vector <4 xi32 >
175174 // CHECK-NOT: wave.mma
176175 // CHECK: amdgpu.mfma
177- // CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
176+ // CHECK-SAME: 32x32x16
178177 %15 = wave.mma %lhs_i8_w8 , %rhs_i8_w8 , %acc_i32_16 {kind = #wave.mma_kind <i32 _32 x32 x16 _i8 >}
179178 : (vector <8 xi8 >, vector <8 xi8 >, vector <16 xi32 >) -> vector <16 xi32 >
180179
0 commit comments