Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
7b64631
Update documentation
amd-eochoalo Nov 10, 2025
08e96b1
Fix verifiers
amd-eochoalo Nov 10, 2025
d0932cc
[mlir][amdgpu] Convert scaled_ext_packed816 to rocdl
amd-eochoalo Nov 10, 2025
0d1d668
Create skeleton for pattern
amd-eochoalo Nov 10, 2025
163b15a
Initial conversion
amd-eochoalo Nov 10, 2025
6d7e2a6
Add first test
amd-eochoalo Nov 10, 2025
fc5d858
Adds two more cases
amd-eochoalo Nov 10, 2025
d9a254f
Add case for pk8.bf16.fp4
amd-eochoalo Nov 11, 2025
c5eb698
Add conversion for pk8.bf16.bf8
amd-eochoalo Nov 11, 2025
cec5f04
Fix and add new case
amd-eochoalo Nov 11, 2025
7dc3442
Add case for pk8.f32.fp4
amd-eochoalo Nov 11, 2025
0ba6b94
Add case for pk8.f32.fp8
amd-eochoalo Nov 11, 2025
551849e
Add case for pk8.f32.bf8
amd-eochoalo Nov 11, 2025
c1e10c8
Add case for pk16.f16.bf6
amd-eochoalo Nov 11, 2025
e958d56
Add case for pk16.bf16.fp6
amd-eochoalo Nov 11, 2025
1f79bdd
Add case for pk16.bf16.bf6
amd-eochoalo Nov 11, 2025
c5628e6
Add case for pk16.f32.fp6
amd-eochoalo Nov 11, 2025
db56c98
Add case for pk16.f32.bf6
amd-eochoalo Nov 11, 2025
73ed4b7
Refactor NFC
amd-eochoalo Nov 11, 2025
94cc740
Refactor NFC
amd-eochoalo Nov 11, 2025
4f27e04
Use method instead of isa
amd-eochoalo Nov 11, 2025
0f8f3c4
Hoist variable
amd-eochoalo Nov 11, 2025
a7a853e
Refactor NFC
amd-eochoalo Nov 11, 2025
9e4ab0e
Hoist variable. NFC
amd-eochoalo Nov 11, 2025
728686f
Comments. NFC
amd-eochoalo Nov 11, 2025
47dc32e
refactor. nfc
amd-eochoalo Nov 11, 2025
3cfea7e
Keep conventions
amd-eochoalo Nov 11, 2025
2b010cd
Less of exhaustive enumeration
amd-eochoalo Nov 17, 2025
6f07ef0
Correct types
amd-eochoalo Nov 17, 2025
6978793
Reflow comment
amd-eochoalo Nov 17, 2025
ed66571
superfluous empty line
amd-eochoalo Nov 17, 2025
33ef57e
Using
amd-eochoalo Nov 17, 2025
a83cec9
Add chipset check and moved tests
amd-eochoalo Nov 17, 2025
34ed3e9
Refactor NFC
amd-eochoalo Nov 17, 2025
b88f7f6
Use operation name
amd-eochoalo Nov 17, 2025
7a7ecaf
Convert result type
amd-eochoalo Nov 17, 2025
1025e2b
Check for type conversion failures
amd-eochoalo Nov 17, 2025
a3db728
Merge branch 'main' into eochoa/2025-11-10/lowerings
amd-eochoalo Nov 17, 2025
7c44f09
Add top-level if condition for each src type
amd-eochoalo Nov 17, 2025
f06a67e
Add chipset constant at beginning of file
amd-eochoalo Nov 17, 2025
1dbcb95
wip
amd-eochoalo Nov 17, 2025
eee0ce9
Add invalid srcElemType case
amd-eochoalo Nov 17, 2025
9860cdd
Update verifiers
amd-eochoalo Nov 17, 2025
0b8f561
Update verifier message
amd-eochoalo Nov 17, 2025
642414a
Fix assertion
amd-eochoalo Nov 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def AMDGPU_ScaledExtPacked816Op
FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<3>]>:$firstScaleByte)>,
Results<(
outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
FixedVectorOfShapeAndType<[8], F16>,
Expand All @@ -139,17 +139,21 @@ def AMDGPU_ScaledExtPacked816Op
let summary = "Extend a vector of packed floating point values";

let description = [{
The scales applied to the input microfloats are stored in two bytes which
The scales applied to the input microfloats are stored in bytes which
come from the `scales` input provided in a *half* of the wave identified
by `firstScaleLane`. The pair of bytes used is selected by
`firstScaleByte`. The 16 vectors in consecutive lanes starting from
by `firstScaleLane`. The bytes used is selected by `firstScaleByte` and depends
on the type of `source`. The 16 vectors in consecutive lanes starting from
`firstScaleLane` (which we'll call the scale vectors) will be used by both
halves of the wave (with lane L reading from L % 16'th scale vector), but
each half will use a different byte.
halves of the wave (with lane L reading from L % 16'th scale vector).

When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN each half of the
wave will use a different byte. The first one being `firstScaleByte` and
the second one being `firstScaleByte` + 1. When the block size is 32,
`firstScaleByte` can be either 0 or 2, selecting halves of the scale vectors.
Lanes 0-15 will read from `firstScaleByte` and lanes 16-31 will read
from `firstScaleByte` + 1.


When the block size is 32, `firstScaleByte` can be either 0 or 2,
selecting halves of the scale vectors. Lanes 0-15 will read from
`firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
For example:
```mlir
// Input: 8-element vector of F8E4M3FN, converting to F32
Expand All @@ -165,7 +169,8 @@ def AMDGPU_ScaledExtPacked816Op
: vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
```

However, when the block size is 16, `firstScaleByte` can be 0 or 1.
When `source` is either F4E2M1FN, F6E2M3FN, or F6E3M2FN and
the block size is 16, `firstScaleByte` can be 0 or 1.
Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
while lanes 16-31 read from `firstScaleByte` + 2.
For example:
Expand All @@ -187,6 +192,16 @@ def AMDGPU_ScaledExtPacked816Op
instructions use for matix scales. These selection operands allows
one to choose portions of the matrix to convert.

When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32,
then the same byte will be used by both halves of the wave.
In this case, `firstScaleByte` can be any value from 0 to 3.

When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16,
following combinations are allowed:
* `firstScaleLane(0), firstScaleByte(0)`
* `firstScaleLane(1), firstScaleByte(2)`
all other combinations are reserved.

Available on gfx1250+.
}];

Expand Down
198 changes: 194 additions & 4 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1492,6 +1492,19 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};

struct ScaledExtPacked816OpLowering final
: public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter,
Chipset chipset)
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
chipset(chipset) {}
Chipset chipset;

LogicalResult
matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
Expand Down Expand Up @@ -1600,6 +1613,182 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
return success();
}

int getScaleSel(int blockSize, int bitWidth, int firstScaleLane,
int firstScaleByte) {
// When lowering amdgpu.scaled_ext_packed816 to
// rocdl.cvt.scale.pk*.f*.f* operations, the
// attributes blockSize, sourceType, firstScaleLane and firstScaleByte
// are merged into a single attribute scaleSel.
//
// This is how those values are merged together.
assert(llvm::is_contained({16, 32}, blockSize));
assert(llvm::is_contained({4, 6, 8}, bitWidth));

const bool is_fp8 = bitWidth == 8;
const bool is_block_16 = blockSize == 16;

if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
return 0b000;
}
if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
return 0b001;
}
if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
return 0b010;
}
if (!is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && is_block_16) {
return 0b011;
}
if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && !is_block_16) {
return 0b100;
}
if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 0 && is_block_16) {
return 0b101;
}
if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
return 0b110;
}
if (!is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
return 0b111;
}

if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && !is_block_16) {
return 0b0000;
}
if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 0 && is_block_16) {
return 0b0001;
}
if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 1 && !is_block_16) {
return 0b0010;
}
if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 2 && !is_block_16) {
return 0b0100;
}
if (is_fp8 && firstScaleLane == 0 && firstScaleByte == 3 && !is_block_16) {
return 0b0110;
}
if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 1 && !is_block_16) {
return 0b1010;
}
if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && !is_block_16) {
return 0b1100;
}
if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 2 && is_block_16) {
return 0b1101;
}
if (is_fp8 && firstScaleLane == 1 && firstScaleByte == 3 && !is_block_16) {
return 0b1110;
}

llvm_unreachable("invalid combination of firstScaleLane, firstScaleByte, "
"blockSize and type.");
return 0;
}

LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

int firstScaleLane = op.getFirstScaleLane();
int firstScaleByte = op.getFirstScaleByte();
int blockSize = op.getBlockSize();
auto sourceType = cast<VectorType>(op.getSource().getType());
auto srcElemType = cast<FloatType>(sourceType.getElementType());
int bitWidth = srcElemType.getWidth();
int scaleSel =
getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);

auto targetType = cast<VectorType>(op.getResult().getType());
auto destElemType = cast<FloatType>(targetType.getElementType());
Location loc = op.getLoc();
IntegerType i32 = rewriter.getI32Type();
Value castedScale =
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());

Value source = adaptor.getSource();
Type packedType;
if (isa<Float4E2M1FNType>(srcElemType)) {
packedType = i32;
packedType = getTypeConverter()->convertType(packedType);
} else if (isa<Float8E4M3FNType>(srcElemType) ||
isa<Float8E5M2Type>(srcElemType)) {
packedType = VectorType::get(2, i32);
packedType = getTypeConverter()->convertType(packedType);
} else if (isa<Float6E2M3FNType>(srcElemType) ||
isa<Float6E3M2FNType>(srcElemType)) {
packedType = VectorType::get(3, i32);
packedType = getTypeConverter()->convertType(packedType);
} else {
llvm_unreachable("invalid element type for scaled ext");
}
// smallT = [Fp4, Fp8, Bf8]
// Bf8 = E5M2
// Fp8 = E4M3
//
// largeT = [F16, Bf16, F32]
// CvtPkScalePk8${largeT}${smallT}
Value castedSource =
LLVM::BitcastOp::create(rewriter, loc, packedType, source);

if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8Bf16Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float4E2M1FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp4Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E4M3FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Fp8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float8E5M2Type>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk8F32Bf8Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
}
// smallT = [Fp6, Bf6]
// Fp6 = Float6E2M3FN
// Bf6 = Float6E3M2FN
// largeT = [F16, Bf16, F32]
//
// CvtPkScalePk16${largeT}${smallT}
else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isBF16()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16Bf16Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E2M3FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Fp6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else if (isa<Float6E3M2FNType>(srcElemType) && destElemType.isF32()) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkScalePk16F32Bf6Op>(
op, op.getResult().getType(), castedSource, castedScale, scaleSel);
} else {
return failure();
}

return success();
}

LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
ScaledExtPackedOp op, ScaledExtPackedOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -2138,9 +2327,10 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering, TransposeLoadOpLowering,
AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
23 changes: 18 additions & 5 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,27 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
LogicalResult ScaledExtPacked816Op::verify() {
int blockSize = getBlockSize();
assert((blockSize == 16 || blockSize == 32) && "invalid block size");

int firstScaleByte = getFirstScaleByte();
if (blockSize == 16 && !llvm::is_contained({0, 1}, firstScaleByte)) {
return emitOpError(
"blockSize of 16 can only have firstScaleByte be 0 or 1.");
auto sourceType = cast<VectorType>(getSource().getType());
Type elementType = sourceType.getElementType();
auto floatType = cast<FloatType>(elementType);
int bitWidth = floatType.getWidth();

if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 16 &&
!llvm::is_contained({0, 1}, firstScaleByte)) {
return emitOpError("blockSize of 16 can only have firstScaleByte be 0 or 1 "
"for f4 and f6.");
}
if (llvm::is_contained({4, 6}, bitWidth) && blockSize == 32 &&
!llvm::is_contained({0, 2}, firstScaleByte)) {
return emitOpError("blockSize of 32 can only have firstScaleByte be 0 or 2 "
"for f4 and f6.");
}
if (blockSize == 32 && !llvm::is_contained({0, 2}, firstScaleByte)) {
if (bitWidth == 8 && blockSize == 16 &&
!llvm::is_contained({0, 2}, firstScaleByte)) {
return emitOpError(
"blockSize of 32 can only have firstScaleByte be 0 or 2.");
"blockSize of 16 can only have firstScaleByte be 0 or 2 for f8.");
}

return success();
Expand Down
Loading
Loading