Skip to content

Commit 9860cdd

Browse files
committed
Update verifiers
1 parent eee0ce9 commit 9860cdd

File tree

3 files changed

+37
-17
lines changed

3 files changed

+37
-17
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1634,7 +1634,7 @@ int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
16341634
// firstScaleByte are merged into a single attribute scaleSel. This is how
16351635
// those values are merged together.
16361636
assert(llvm::is_contained({16, 32}, blockSize));
1637-
assert(llvm::is_contained(::llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
1637+
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
16381638

16391639
const bool is_fp8 = bitWidth == 8;
16401640
const bool is_block_16 = blockSize == 16;

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 29 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -343,28 +343,41 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
343343
//===----------------------------------------------------------------------===//
344344
LogicalResult ScaledExtPacked816Op::verify() {
345345
int blockSize = getBlockSize();
346-
assert((blockSize == 16 || blockSize == 32) && "invalid block size");
346+
assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
347347

348348
int firstScaleByte = getFirstScaleByte();
349+
int firstScaleLane = getFirstScaleLane();
349350
auto sourceType = cast<VectorType>(getSource().getType());
350351
Type elementType = sourceType.getElementType();
351352
auto floatType = cast<FloatType>(elementType);
352-
int bitWidth = floatType.getWidth();
353+
unsigned bitWidth = floatType.getWidth();
353354

354-
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
355-
!llvm::is_contained({0, 1}, firstScaleByte)) {
356-
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
357-
"for f4 and f6.");
358-
}
359-
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
360-
!llvm::is_contained({0, 2}, firstScaleByte)) {
361-
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
362-
"for f4 and f6.");
363-
}
364-
if (bitWidth == 8 && blockSize == 16 &&
365-
!llvm::is_contained({0, 2}, firstScaleByte)) {
366-
return emitOpError(
367-
"blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
355+
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
356+
357+
const bool is_fp8 = bitWidth == 8;
358+
const bool is_block_16 = blockSize == 16;
359+
360+
if (!is_fp8) {
361+
if (is_block_16) {
362+
if (!llvm::is_contained({0, 1}, firstScaleByte)) {
363+
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 "
364+
"or 1 for f4 and f6.");
365+
}
366+
} else {
367+
if (!llvm::is_contained({0, 2}, firstScaleByte)) {
368+
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 "
369+
"or 2 for f4 and f6.");
370+
}
371+
}
372+
} else {
373+
if (is_block_16) {
374+
bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
375+
((firstScaleLane == 1) && (firstScaleByte == 2));
376+
if (!is_valid) {
377+
return emitOpError(
378+
"blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
379+
}
380+
}
368381
}
369382

370383
return success();

mlir/test/Conversion/AMDGPUToROCDL/cvt_scale_pk-gfx1250.mlir

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,3 +155,10 @@ func.func @amdgpu.scaled_ext_packed816_invalid_src_elem_type(%v: vector<16xf16>,
155155
return %ret0: vector<16xf16>
156156
}
157157

158+
// -----
159+
160+
func.func @amdgpu.scaled_ext_packed816_invalid_dst_elem_type(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf64>) {
161+
// expected-error@+1 {{'amdgpu.scaled_ext_packed816' op result #0 must be vector}}
162+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf64>
163+
return %ret0: vector<16xf64>
164+
}

0 commit comments

Comments
 (0)