Skip to content

Commit 08e96b1

Browse files
committed
Fix verifiers
1 parent 7b64631 commit 08e96b1

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,14 +344,27 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
344344
LogicalResult ScaledExtPacked816Op::verify() {
345345
int blockSize = getBlockSize();
346346
assert((blockSize == 16 || blockSize == 32) && "invalid block size");
347+
347348
int firstScaleByte = getFirstScaleByte();
348-
if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
349-
return emitOpError(
350-
"blockSize of 16 can only have firstScaleByte be 0 or 1.");
349+
auto sourceType = cast<VectorType>(getSource().getType());
350+
Type elementType = sourceType.getElementType();
351+
auto floatType = cast<FloatType>(elementType);
352+
int bitWidth = floatType.getWidth();
353+
354+
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
355+
!llvm::is_contained({0, 1}, firstScaleByte)) {
356+
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
357+
"for f4 and f6.");
358+
}
359+
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
360+
!llvm::is_contained({0, 2}, firstScaleByte)) {
361+
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
362+
"for f4 and f6.");
351363
}
352-
if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
364+
if (bitWidth == 8 && blockSize == 16 &&
365+
!llvm::is_contained({0, 2}, firstScaleByte)) {
353366
return emitOpError(
354-
"blockSize of 32 can only have firstScaleByte be 0 or 2.");
367+
"blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
355368
}
356369

357370
return success();

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,17 +333,25 @@ func.func @gather_to_lds_non_lds(%idx1 : index, %mem1 : memref<32xf16>, %mem2 :
333333

334334
// -----
335335

336-
func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
337-
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1.}}
338-
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
336+
func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_16(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
337+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 1 for f4 and f6}}
338+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(2) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
339339
func.return
340340
}
341341

342342
// -----
343343

344-
func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
345-
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2.}}
346-
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
344+
func.func @amdgpu.scaled_ext_packed816_invalid_block_size_and_first_scale_byte_32(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) {
345+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 32 can only have firstScaleByte be 0 or 2 for f4 and f6.}}
346+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(1) : vector<8xf4E2M1FN>, vector<4xf8E8M0FNU> -> vector<8xf16>
347+
func.return
348+
}
349+
350+
// -----
351+
352+
func.func @amdgpu.scaled_ext_packed816_invalid_attributes_for_f8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) {
353+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.}}
354+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(1) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xf16>
347355
func.return
348356
}
349357

0 commit comments

Comments
 (0)