@@ -93,10 +93,10 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
9393func.func @wmma_scale_16x16x128_fp8 (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <64 xf6 E2 M3 FN>,
9494 %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>) {
9595 // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
96- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg0 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <8 xf32 >
96+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg0 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <8 xf32 >
9797
9898 // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
99- %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 1 ] * %arg1 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
99+ %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg1 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 1 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
100100
101101 func.return
102102}
@@ -105,10 +105,10 @@ func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<
105105func.func @wmma_scale_16x16x128_fp6 (%arg0 : vector <64 xf6 E2 M3 FN>, %arg1 : vector <64 xf6 E3 M2 FN>,
106106 %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>) {
107107 // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
108- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg0 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
108+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg0 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
109109
110110 // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
111- %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg1 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <8 xf32 >
111+ %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg1 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <8 xf32 >
112112
113113 func.return
114114}
@@ -118,10 +118,10 @@ func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vecto
118118 %arg2 : vector <64 xf4 E2 M1 FN>, %arg3 : vector <8 xf32 >,
119119 %arg4 : vector <4 xf8 E8 M0 FNU>, %arg5 : vector <4 xf8 E4 M3 FN>) {
120120 // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
121- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg4 [ 0 ] * %arg0 ) * (%arg5 [ 0 ] * %arg2 ) + %arg3 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
121+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg4 * %arg0 ) * (%arg5 * %arg2 ) + %arg3 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
122122
123123 // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, {{.*}}, {{.*}} {fmtA = 2 : i32, fmtB = 4 : i32, fmtScaleB = 2 : i32} : (vector<12xi32>, vector<8xi32>, vector<8xf32>, i32, i32) -> vector<8xf32>
124- %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg4 [ 0 ] * %arg1 ) * (%arg5 [ 0 ] * %arg2 ) + %arg3 : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
124+ %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg4 * %arg1 ) * (%arg5 * %arg2 ) + %arg3 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf6 E2 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
125125
126126 func.return
127127}
@@ -130,10 +130,10 @@ func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vecto
130130func.func @wmma_scale16_16x16x128_fp8 (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <64 xf6 E3 M2 FN>,
131131 %arg2 : vector <8 xf32 >, %arg3 : vector <8 xf8 E8 M0 FNU>) {
132132 // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
133- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg0 ) + %arg2 : vector <8 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <8 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <8 xf32 >
133+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg0 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <8 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <8 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <8 xf32 >
134134
135135 // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtA = 3 : i32, fmtB = 3 : i32, scaleAType = 1 : i32} : (vector<12xi32>, vector<12xi32>, vector<8xf32>, i64, i64) -> vector<8xf32>
136- %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 1 ] * %arg1 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <8 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <8 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <8 xf32 >
136+ %1 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg1 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 1 : i32 , scaleBIdx = 0 : i32 } : vector <8 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <8 xf8 E8 M0 FNU>, vector <64 xf6 E3 M2 FN>, vector <8 xf32 >
137137
138138 func.return
139139}
@@ -142,7 +142,7 @@ func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vecto
142142func.func @wmma_scale_32x16x128_fp4 (%arg0 : vector <128 xf4 E2 M1 FN>, %arg1 : vector <64 xf4 E2 M1 FN>,
143143 %arg2 : vector <16 xf32 >, %arg3 : vector <4 xf8 E4 M3 FN>) {
144144 // CHECK: rocdl.wmma.scale.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
145- %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
145+ %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
146146
147147 func.return
148148}
@@ -151,7 +151,7 @@ func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector
151151func.func @wmma_scale16_32x16x128_fp4 (%arg0 : vector <128 xf4 E2 M1 FN>, %arg1 : vector <64 xf4 E2 M1 FN>,
152152 %arg2 : vector <16 xf32 >, %arg3 : vector <8 xf8 E4 M3 FN>) {
153153 // CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 {{.*}}, {{.*}}, %arg2, {{.*}}, {{.*}} {fmtScaleA = 2 : i32, fmtScaleB = 2 : i32} : (vector<16xi32>, vector<8xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
154- %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <8 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <8 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
154+ %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <8 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <8 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
155155
156156 func.return
157157}
@@ -170,42 +170,42 @@ func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
170170func.func @scaled_wmma_wrong_output_length (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <16 xf32 >,
171171 %arg2 : vector <4 xf8 E8 M0 FNU>) {
172172 // expected-error@below {{'amdgpu.scaled_wmma' op expected output vector of length 8 but got 16}}
173- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg2 [ 0 ] * %arg0 ) * (%arg2 [ 0 ] * %arg0 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <16 xf32 >
173+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg2 * %arg0 ) * (%arg2 * %arg0 ) + %arg1 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <16 xf32 >
174174 return
175175}
176176
177177func.func @scaled_wmma_16x16_wrong_sourceA_length (%arg0 : vector <128 xf4 E2 M1 FN>, %arg1 : vector <64 xf4 E2 M1 FN>,
178178 %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>) {
179179 // expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceA must have 64 elements but got 128}}
180- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
180+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
181181 return
182182}
183183
184184func.func @scaled_wmma_16x16_wrong_sourceB_length (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <128 xf4 E2 M1 FN>,
185185 %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>) {
186186 // expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceB must have 64 elements but got 128}}
187- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <128 xf4 E2 M1 FN>, vector <8 xf32 >
187+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <128 xf4 E2 M1 FN>, vector <8 xf32 >
188188 return
189189}
190190
191191func.func @scaled_wmma_32x16_wrong_sourceA_length (%arg0 : vector <64 xf4 E2 M1 FN>, %arg1 : vector <64 xf4 E2 M1 FN>,
192192 %arg2 : vector <16 xf32 >, %arg3 : vector <4 xf8 E4 M3 FN>) {
193193 // expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceA must have 128 elements but got 64}}
194- %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
194+ %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
195195 return
196196}
197197
198198func.func @scaled_wmma_32x16_wrong_sourceB_length (%arg0 : vector <128 xf4 E2 M1 FN>, %arg1 : vector <128 xf4 E2 M1 FN>,
199199 %arg2 : vector <16 xf32 >, %arg3 : vector <4 xf8 E4 M3 FN>) {
200200 // expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceB must have 64 elements but got 128}}
201- %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg3 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <16 xf32 >
201+ %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 * %arg0 ) * (%arg3 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <16 xf32 >
202202 return
203203}
204204
205205func.func @scaled_wmma_invalid_type_combination (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <64 xf6 E2 M3 FN>,
206206 %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>,
207207 %arg4 : vector <4 xf8 E4 M3 FN>) {
208208 // expected-error@below {{'amdgpu.scaled_wmma' op invalid combination of matrix and scale types}}
209- %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [ 0 ] * %arg0 ) * (%arg4 [ 0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
209+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 * %arg0 ) * (%arg4 * %arg1 ) + %arg2 { scaleAIdx = 0 : i32 , scaleBIdx = 0 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
210210 return
211211}
0 commit comments