Skip to content

Commit e209b8b

Browse files
[mlir][AMDGPU] Rename gfx1250 packed extension ops, change firstScaleLane (#170718)
The current name of scaled_ext_packed816 was, in retrospect, bothering me, since it just has a bunch of numbers on the end and doesn't really reflect the wave-wide nature of the operation. On top of that, the fact that firstScaleLane was 0 or 1, which might be read as the first lane being 1 (and not what it actually was, 16), also seemed weird. Therefore, before this op sees any use, 1. Renaem it to scaled_ext_packed_matrix 2. Change the semantics of firstScaleLane to actually point at the lane where the scales start (valid options currently are 0 or 16, the two halves of a wave32 wave). (Disclaimer: the mechanical updates were done via AI.) --------- Co-authored-by: Erick Ochoa Lopez <[email protected]>
1 parent 6969ac8 commit e209b8b

File tree

5 files changed

+156
-154
lines changed

5 files changed

+156
-154
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,17 @@ def AMDGPU_ExtPackedFp8Op :
146146
}];
147147
}
148148

149-
def IsValidBlockSize: AttrConstraint<
150-
CPred<"::llvm::is_contained({16, 32}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">,
151-
"whose value is 16 or 32">;
152-
153-
def AMDGPU_ScaledExtPacked816Op
154-
: AMDGPU_Op<"scaled_ext_packed816", [Pure, AllShapesMatch<["source", "res"]>]>,
149+
def AMDGPU_ScaledExtPackedMatrixOp
150+
: AMDGPU_Op<"scaled_ext_packed_matrix", [Pure, AllShapesMatch<["source", "res"]>]>,
155151
Arguments<(
156152
ins AnyTypeOf<[FixedVectorOfShapeAndType<[8], F4E2M1FN>,
157153
FixedVectorOfShapeAndType<[8], F8E4M3FN>,
158154
FixedVectorOfShapeAndType<[8], F8E5M2>,
159155
FixedVectorOfShapeAndType<[16], F6E2M3FN>,
160156
FixedVectorOfShapeAndType<[16], F6E3M2FN>]>:$source,
161157
FixedVectorOfShapeAndType<[4], F8E8M0FNU>:$scale,
162-
ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
163-
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
158+
ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$blockSize,
159+
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 16]>]>:$firstScaleLane,
164160
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<3>]>:$firstScaleByte)>,
165161
Results<(
166162
outs AnyTypeOf<[FixedVectorOfShapeAndType<[8], F32>,
@@ -170,9 +166,12 @@ def AMDGPU_ScaledExtPacked816Op
170166
FixedVectorOfShapeAndType<[16], F16>,
171167
FixedVectorOfShapeAndType<[16], BF16>]>:$res)> {
172168

173-
let summary = "Extend a vector of packed floating point values";
169+
let summary = "Extend a wave-wide matrix of packed floating point values";
174170

175171
let description = [{
172+
Extend matrix of microfloats (8 or 16 elements per lane) using a set of scales
173+
that may be stored on other lanes.
174+
176175
The scales applied to the input microfloats are stored in bytes which
177176
come from the `scales` input provided in a *half* of the wave identified
178177
by `firstScaleLane`. The bytes used is selected by `firstScaleByte` and depends
@@ -192,14 +191,14 @@ def AMDGPU_ScaledExtPacked816Op
192191
```mlir
193192
// Input: 8-element vector of F8E4M3FN, converting to F32
194193
// Lanes 0-15 read from byte 0, lanes 16-31 read from byte 1
195-
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
194+
%result = amdgpu.scaled_ext_packed_matrix %source scale(%scales)
196195
blockSize(32) firstScaleLane(0) firstScaleByte(0)
197196
: vector<8xf8E4M3FN>, vector<4xf8E8M0FNU> -> vector<8xf32>
198197

199198
// Input: 16-element vector of F6E2M3FN, converting to F16
200199
// Lanes 0-15 read from byte 2, lanes 16-31 read from byte 3
201-
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
202-
blockSize(32) firstScaleLane(1) firstScaleByte(2)
200+
%result = amdgpu.scaled_ext_packed_matrix %source scale(%scales)
201+
blockSize(32) firstScaleLane(16) firstScaleByte(2)
203202
: vector<16xf6E2M3FN>, vector<4xf8E8M0FNU> -> vector<16xf16>
204203
```
205204

@@ -211,19 +210,19 @@ def AMDGPU_ScaledExtPacked816Op
211210
```mlir
212211
// Input: 8-element vector of F8E5M2, converting to BF16
213212
// Lanes 0-15 read from byte 0, lanes 16-31 read from byte 2 (0+2)
214-
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
213+
%result = amdgpu.scaled_ext_packed_matrix %source scale(%scales)
215214
blockSize(16) firstScaleLane(0) firstScaleByte(0)
216215
: vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<8xbf16>
217216

218217
// Input: 16-element vector of F6E3M2FN, converting to F32
219218
// Lanes 0-15 read from byte 1, lanes 16-31 read from byte 3 (1+2)
220-
%result = amdgpu.scaled_ext_packed816 %source scale(%scales)
221-
blockSize(16) firstScaleLane(1) firstScaleByte(1)
219+
%result = amdgpu.scaled_ext_packed_matrix %source scale(%scales)
220+
blockSize(16) firstScaleLane(16) firstScaleByte(1)
222221
: vector<16xf6E3M2FN>, vector<4xf8E8M0FNU> -> vector<16xf32>
223222
```
224223

225224
Note: the layout for the scales generally mirrors how the WMMA
226-
instructions use for matix scales. These selection operands allows
225+
instructions use for matrix scales. These selection operands allows
227226
one to choose portions of the matrix to convert.
228227

229228
When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 32,
@@ -233,7 +232,7 @@ def AMDGPU_ScaledExtPacked816Op
233232
When `source` is either F8E4M3FN or F8E5M2 and `blockSize` is 16,
234233
following combinations are allowed:
235234
* `firstScaleLane(0), firstScaleByte(0)`
236-
* `firstScaleLane(1), firstScaleByte(2)`
235+
* `firstScaleLane(16), firstScaleByte(2)`
237236
all other combinations are reserved.
238237

239238
Available on gfx1250+.

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,16 +1506,17 @@ struct ExtPackedFp8OpLowering final
15061506
ConversionPatternRewriter &rewriter) const override;
15071507
};
15081508

1509-
struct ScaledExtPacked816OpLowering final
1510-
: public ConvertOpToLLVMPattern<ScaledExtPacked816Op> {
1511-
ScaledExtPacked816OpLowering(const LLVMTypeConverter &converter,
1512-
Chipset chipset)
1513-
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPacked816Op>(converter),
1509+
struct ScaledExtPackedMatrixOpLowering final
1510+
: public ConvertOpToLLVMPattern<ScaledExtPackedMatrixOp> {
1511+
ScaledExtPackedMatrixOpLowering(const LLVMTypeConverter &converter,
1512+
Chipset chipset)
1513+
: ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedMatrixOp>(converter),
15141514
chipset(chipset) {}
15151515
Chipset chipset;
15161516

15171517
LogicalResult
1518-
matchAndRewrite(ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1518+
matchAndRewrite(ScaledExtPackedMatrixOp op,
1519+
ScaledExtPackedMatrixOpAdaptor adaptor,
15191520
ConversionPatternRewriter &rewriter) const override;
15201521
};
15211522

@@ -1627,34 +1628,35 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
16271628
return success();
16281629
}
16291630

1630-
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth,
1631-
int32_t firstScaleLane, int32_t firstScaleByte) {
1632-
// When lowering amdgpu.scaled_ext_packed816 to rocdl.cvt.scale.pk*.f*.f*
1633-
// operations, the attributes blockSize, sourceType, firstScaleLane and
1631+
int32_t getScaleSel(int32_t blockSize, unsigned bitWidth, int32_t scaleWaveHalf,
1632+
int32_t firstScaleByte) {
1633+
// When lowering amdgpu.scaled_ext_packed_matrix to rocdl.cvt.scale.pk*.f*.f*
1634+
// operations, the attributes blockSize, sourceType, scaleWaveHalf, and
16341635
// firstScaleByte are merged into a single attribute scaleSel. This is how
1635-
// those values are merged together.
1636+
// those values are merged together. (Note: scaleWaveHalf isn't a high-level
1637+
// attribute but is derifed from firstScaleLane).
16361638
assert(llvm::is_contained({16, 32}, blockSize));
16371639
assert(llvm::is_contained(llvm::ArrayRef<unsigned>{4, 6, 8}, bitWidth));
16381640

1639-
const bool is_fp8 = bitWidth == 8;
1640-
const bool is_block_16 = blockSize == 16;
1641+
const bool isFp8 = bitWidth == 8;
1642+
const bool isBlock16 = blockSize == 16;
16411643

1642-
if (!is_fp8) {
1643-
int bit_0 = is_block_16;
1644+
if (!isFp8) {
1645+
int32_t bit0 = isBlock16;
16441646
assert(llvm::is_contained({0, 1, 2}, firstScaleByte));
1645-
int bit_1 = (firstScaleByte == 2) << 1;
1646-
assert(llvm::is_contained({0, 1}, firstScaleLane));
1647-
int bit_2 = firstScaleLane << 2;
1648-
return bit_2 | bit_1 | bit_0;
1647+
int32_t bit1 = (firstScaleByte == 2) << 1;
1648+
assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1649+
int32_t bit2 = scaleWaveHalf << 2;
1650+
return bit2 | bit1 | bit0;
16491651
}
16501652

1651-
int bit_0 = is_block_16;
1653+
int32_t bit0 = isBlock16;
16521654
// firstScaleByte is guaranteed to be defined by two bits.
16531655
assert(llvm::is_contained({0, 1, 2, 3}, firstScaleByte));
1654-
int bit_2_and_1 = firstScaleByte << 1;
1655-
assert(llvm::is_contained({0, 1}, firstScaleLane));
1656-
int bit_3 = firstScaleLane << 3;
1657-
int bits = bit_3 | bit_2_and_1 | bit_0;
1656+
int32_t bits2and1 = firstScaleByte << 1;
1657+
assert(llvm::is_contained({0, 1}, scaleWaveHalf));
1658+
int32_t bit3 = scaleWaveHalf << 3;
1659+
int32_t bits = bit3 | bits2and1 | bit0;
16581660
// These are invalid cases.
16591661
assert(!llvm::is_contained(
16601662
{0b0011, 0b0101, 0b0111, 0b1000, 0b1001, 0b1011, 0b1111}, bits));
@@ -1717,8 +1719,8 @@ scaledExtPacked816ToIntrinsic(Type srcElemType, Type destElemType) {
17171719
"instructions");
17181720
}
17191721

1720-
LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
1721-
ScaledExtPacked816Op op, ScaledExtPacked816OpAdaptor adaptor,
1722+
LogicalResult ScaledExtPackedMatrixOpLowering::matchAndRewrite(
1723+
ScaledExtPackedMatrixOp op, ScaledExtPackedMatrixOpAdaptor adaptor,
17221724
ConversionPatternRewriter &rewriter) const {
17231725
using fp4 = Float4E2M1FNType;
17241726
using fp8 = Float8E4M3FNType;
@@ -1732,7 +1734,9 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17321734
"Scaled fp packed conversion instructions are not available on target "
17331735
"architecture and their emulation is not implemented");
17341736
}
1735-
int32_t firstScaleLane = op.getFirstScaleLane();
1737+
// Convert user-facing firstScaleLane (0 or 16) to the half of the wave that
1738+
// is being selected.
1739+
int32_t scaleWaveHalf = op.getFirstScaleLane() / 16;
17361740
int32_t firstScaleByte = op.getFirstScaleByte();
17371741
int32_t blockSize = op.getBlockSize();
17381742
auto sourceType = cast<VectorType>(op.getSource().getType());
@@ -1770,7 +1774,7 @@ LogicalResult ScaledExtPacked816OpLowering::matchAndRewrite(
17701774
"no intrinsic matching packed scaled conversion on the given chipset");
17711775

17721776
int32_t scaleSel =
1773-
getScaleSel(blockSize, bitWidth, firstScaleLane, firstScaleByte);
1777+
getScaleSel(blockSize, bitWidth, scaleWaveHalf, firstScaleByte);
17741778
Value castedScale =
17751779
LLVM::BitcastOp::create(rewriter, loc, i32, adaptor.getScale());
17761780
Value castedSource =
@@ -2388,27 +2392,26 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
23882392
RewritePatternSet &patterns,
23892393
Chipset chipset) {
23902394
populateAMDGPUMemorySpaceAttributeConversions(converter);
2391-
patterns
2392-
.add<FatRawBufferCastLowering,
2393-
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2394-
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2395-
RawBufferOpLowering<RawBufferAtomicFaddOp,
2396-
ROCDL::RawPtrBufferAtomicFaddOp>,
2397-
RawBufferOpLowering<RawBufferAtomicFmaxOp,
2398-
ROCDL::RawPtrBufferAtomicFmaxOp>,
2399-
RawBufferOpLowering<RawBufferAtomicSmaxOp,
2400-
ROCDL::RawPtrBufferAtomicSmaxOp>,
2401-
RawBufferOpLowering<RawBufferAtomicUminOp,
2402-
ROCDL::RawPtrBufferAtomicUminOp>,
2403-
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2404-
ROCDL::RawPtrBufferAtomicCmpSwap>,
2405-
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2406-
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2407-
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
2408-
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2409-
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2410-
GatherToLDSOpLowering, TransposeLoadOpLowering,
2411-
AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering>(converter,
2412-
chipset);
2395+
patterns.add<
2396+
FatRawBufferCastLowering,
2397+
RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
2398+
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
2399+
RawBufferOpLowering<RawBufferAtomicFaddOp,
2400+
ROCDL::RawPtrBufferAtomicFaddOp>,
2401+
RawBufferOpLowering<RawBufferAtomicFmaxOp,
2402+
ROCDL::RawPtrBufferAtomicFmaxOp>,
2403+
RawBufferOpLowering<RawBufferAtomicSmaxOp,
2404+
ROCDL::RawPtrBufferAtomicSmaxOp>,
2405+
RawBufferOpLowering<RawBufferAtomicUminOp,
2406+
ROCDL::RawPtrBufferAtomicUminOp>,
2407+
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
2408+
ROCDL::RawPtrBufferAtomicCmpSwap>,
2409+
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
2410+
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
2411+
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering,
2412+
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
2413+
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
2414+
GatherToLDSOpLowering, TransposeLoadOpLowering, AMDGPUPermlaneLowering,
2415+
AMDGPUMakeDmaBaseLowering>(converter, chipset);
24132416
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
24142417
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,9 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
343343
}
344344

345345
//===----------------------------------------------------------------------===//
346-
// ScaledExtPacked816Op
346+
// ScaledExtPackedMatrixOp
347347
//===----------------------------------------------------------------------===//
348-
LogicalResult ScaledExtPacked816Op::verify() {
348+
LogicalResult ScaledExtPackedMatrixOp::verify() {
349349
int blockSize = getBlockSize();
350350
assert(llvm::is_contained({16, 32}, blockSize) && "invalid block size");
351351

@@ -376,10 +376,10 @@ LogicalResult ScaledExtPacked816Op::verify() {
376376
} else {
377377
if (is_block_16) {
378378
bool is_valid = ((firstScaleLane == 0) && (firstScaleByte == 0)) ||
379-
((firstScaleLane == 1) && (firstScaleByte == 2));
379+
((firstScaleLane == 16) && (firstScaleByte == 2));
380380
if (!is_valid) {
381381
return emitOpError("blockSize of 16 can only have (firstScaleLane, "
382-
"firstScaleByte) be (0, 0) or (1, 2) for f8.");
382+
"firstScaleByte) be (0, 0) or (16, 2) for f8.");
383383
}
384384
}
385385
}

0 commit comments

Comments
 (0)