Skip to content

Commit b1d2aa2

Browse files
committed
Add negative tests and remove redundant checks
1 parent c1e65fb commit b1d2aa2

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,6 @@ LogicalResult ScaledWMMAOp::verify() {
451451
auto isF8 = llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type>;
452452
auto isF6 = llvm::IsaPred<Float6E2M3FNType, Float6E3M2FNType>;
453453
auto isF4 = llvm::IsaPred<Float4E2M1FNType>;
454-
auto isSmallFloat = [&](Type t) { return isF4(t) || isF6(t) || isF8(t); };
455454
auto isScaleF8 = llvm::IsaPred<Float8E8M0FNUType, Float8E4M3FNType>;
456455
auto isE8M0 = llvm::IsaPred<Float8E8M0FNUType>;
457456
auto isE4M3 = llvm::IsaPred<Float8E4M3FNType>;
@@ -460,18 +459,10 @@ LogicalResult ScaledWMMAOp::verify() {
460459
auto sourceBType = cast<VectorType>(getSourceB().getType());
461460
auto destType = cast<VectorType>(getDestC().getType());
462461

463-
// Validate output type is F32.
464-
if (!destType.getElementType().isF32())
465-
return emitOpError("destination must have f32 element type");
466-
467462
// Validate source element types are small floats (fp4/fp6/fp8).
468463
Type aElemType = sourceAType.getElementType();
469464
Type bElemType = sourceBType.getElementType();
470465

471-
if (!isSmallFloat(aElemType) || !isSmallFloat(bElemType))
472-
return emitOpError("source operands must have small float element types "
473-
"(fp4/fp6/fp8)");
474-
475466
// Validate vector lengths based on dimensions.
476467
int64_t m = getM();
477468
int64_t aLen = sourceAType.getNumElements();

mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,48 @@ func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
164164
amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>
165165
return
166166
}
167+
168+
// -----
169+
170+
func.func @scaled_wmma_wrong_output_length(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<16xf32>,
171+
%arg2 : vector<4xf8E8M0FNU>) {
172+
// expected-error@below {{'amdgpu.scaled_wmma' op expected output vector of length 8 but got 16}}
173+
%0 = amdgpu.scaled_wmma 16x16x128 (%arg2[0] * %arg0) * (%arg2[0] * %arg0) + %arg1 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<16xf32>
174+
return
175+
}
176+
177+
func.func @scaled_wmma_16x16_wrong_sourceA_length(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
178+
%arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
179+
// expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceA must have 64 elements but got 128}}
180+
%0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<64xf4E2M1FN>, vector<8xf32>
181+
return
182+
}
183+
184+
func.func @scaled_wmma_16x16_wrong_sourceB_length(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<128xf4E2M1FN>,
185+
%arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>) {
186+
// expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceB must have 64 elements but got 128}}
187+
%0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E8M0FNU>, vector<128xf4E2M1FN>, vector<8xf32>
188+
return
189+
}
190+
191+
func.func @scaled_wmma_32x16_wrong_sourceA_length(%arg0 : vector<64xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
192+
%arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
193+
// expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceA must have 128 elements but got 64}}
194+
%0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<4xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
195+
return
196+
}
197+
198+
func.func @scaled_wmma_32x16_wrong_sourceB_length(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<128xf4E2M1FN>,
199+
%arg2 : vector<16xf32>, %arg3 : vector<4xf8E4M3FN>) {
200+
// expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceB must have 64 elements but got 128}}
201+
%0 = amdgpu.scaled_wmma 32x16x128 (%arg3[0] * %arg0) * (%arg3[0] * %arg1) + %arg2 : vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<4xf8E4M3FN>, vector<128xf4E2M1FN>, vector<16xf32>
202+
return
203+
}
204+
205+
func.func @scaled_wmma_invalid_type_combination(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
206+
%arg2 : vector<8xf32>, %arg3 : vector<4xf8E8M0FNU>,
207+
%arg4 : vector<4xf8E4M3FN>) {
208+
// expected-error@below {{'amdgpu.scaled_wmma' op invalid combination of matrix and scale types}}
209+
%0 = amdgpu.scaled_wmma 16x16x128 (%arg3[0] * %arg0) * (%arg4[0] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf8E4M3FN>, vector<64xf6E2M3FN>, vector<8xf32>
210+
return
211+
}

0 commit comments

Comments
 (0)