Skip to content

Commit 9c09c35

Browse files
committed
[mlir][amdgpu] Add scaled_ext_packed{8,16} operations
1 parent 251ae55 commit 9c09c35

File tree

3 files changed

+174
-1
lines changed

3 files changed

+174
-1
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,54 @@ def AMDGPU_ExtPackedFp8Op :
112112
}];
113113
}
114114

115+
def IsValidBlockSize: AttrConstraint<
116+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == 16 || ::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == 32">,
117+
"whose value is 16 or 32">;
118+
119+
def AMDGPU_ScaledExtPacked816Op
120+
: AMDGPU_Op<"scaled_ext_packed816", [Pure]>,
121+
Arguments<(
122+
ins AnyTypeOf<[VectorOfLengthAndType<[8], [F4E2M1FN,F8E4M3FN,F8E5M2]>,
123+
VectorOfLengthAndType<[16], [F6E2M3FN, F6E3M2FN]>]>:$source,
124+
FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>:$scale,
125+
ConfinedAttr<I32Attr, [IsValidBlockSize]>:$blockSize,
126+
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>:$firstScaleLane,
127+
ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<2>]>:$firstScaleByte)>,
128+
Results<(
129+
outs AnyTypeOf<[FixedVectorOfLengthAndType<[8], [F32]>,
130+
FixedVectorOfLengthAndType<[8], [F16]>,
131+
FixedVectorOfLengthAndType<[8], [BF16]>,
132+
FixedVectorOfLengthAndType<[16], [F32]>,
133+
FixedVectorOfLengthAndType<[16], [F16]>,
134+
FixedVectorOfLengthAndType<[16], [BF16]>]>:$res)> {
135+
136+
let summary = "Extend a vector of packed floating point values";
137+
138+
let description = [{
139+
The scales applied to the input microfloats are stored in two bytes which
140+
come from the `scales` input provided in a *half* of the wave identified
141+
by `firstScaleLane`. The pair of bytes used is selected by
142+
`firstScaleByte`. The 16 vectors in consecutive lanes starting from
143+
`firstScaleLane` (which we'll call the scale vectors) will be used by both
144+
halves of the wave (with lane L reading from L % 16'th scale vector), but
145+
each half will use a different byte.
146+
147+
When the block size is 32, `firstScaleByte` can be either 0 or 2,
148+
selecting halves of the scale vectors. Lanes 0-15 will read from
149+
`firstScaleByte` and lanes 16-31 will read from `firstScaleByte` + 1.
150+
151+
However, when the block size is 16, `firstScaleByte` can be 0 or 1.
152+
Lanes 0-15 read from the `firstScaleByte`th element of the scale vectors,
153+
while lanes 16-31 read from `firstScaleByte` + 2.
154+
155+
Note: the layout for the scales generally mirrors how the WMMA
156+
instructions use matrix scales. These selection operands allow
157+
one to choose portions of the matrix to convert.
158+
}];
159+
160+
let hasCustomAssemblyFormat = 1;
161+
}
162+
115163
def AMDGPU_ScaledExtPackedOp
116164
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
117165
Arguments<(
@@ -860,7 +908,7 @@ def AMDGPU_MFMAOp :
860908
based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
861909
types of the source and destination arguments.
862910

863-
For information on the layouts of the input and output matrces (which are stored
911+
For information on the layouts of the input and output matrices (which are stored
864912
in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
865913

866914
The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,76 @@ void RawBufferAtomicCmpswapOp::getCanonicalizationPatterns(
338338
context);
339339
}
340340

341+
//===----------------------------------------------------------------------===//
342+
// ScaledExtPacked816Op
343+
//===----------------------------------------------------------------------===//
344+
mlir::ParseResult ScaledExtPacked816Op::parse(mlir::OpAsmParser &parser,
345+
mlir::OperationState &result) {
346+
// Parse attributes
347+
if (parser.parseOptionalAttrDict(result.attributes))
348+
return failure();
349+
350+
// Parse source operand
351+
OpAsmParser::UnresolvedOperand source;
352+
if (parser.parseOperand(source))
353+
return failure();
354+
355+
if (parser.parseKeyword("scale") || parser.parseLParen())
356+
return failure();
357+
OpAsmParser::UnresolvedOperand scale;
358+
if (parser.parseOperand(scale) || parser.parseRParen())
359+
return failure();
360+
361+
// Parse attributes
362+
IntegerAttr blockSize, firstScaleLane, firstScaleByte;
363+
if (parser.parseKeyword("blockSize") || parser.parseLParen() ||
364+
parser.parseAttribute(blockSize, parser.getBuilder().getI32Type()) ||
365+
parser.parseRParen())
366+
return failure();
367+
368+
if (parser.parseKeyword("firstScaleLane") || parser.parseLParen() ||
369+
parser.parseAttribute(firstScaleLane, parser.getBuilder().getI32Type()) ||
370+
parser.parseRParen())
371+
return failure();
372+
373+
if (parser.parseKeyword("firstScaleByte") || parser.parseLParen() ||
374+
parser.parseAttribute(firstScaleByte, parser.getBuilder().getI32Type()) ||
375+
parser.parseRParen())
376+
return failure();
377+
378+
Type sourceType, resultType;
379+
if (parser.parseColon() || parser.parseType(sourceType) ||
380+
parser.parseKeyword("to") || parser.parseType(resultType))
381+
return failure();
382+
383+
// Resolve operands with types
384+
Type scaleType =
385+
VectorType::get({4}, Float8E8M0FNUType::get(parser.getContext()));
386+
if (parser.resolveOperand(source, sourceType, result.operands) ||
387+
parser.resolveOperand(scale, scaleType, result.operands))
388+
return failure();
389+
390+
result.addAttribute("blockSize", blockSize);
391+
result.addAttribute("firstScaleLane", firstScaleLane);
392+
result.addAttribute("firstScaleByte", firstScaleByte);
393+
394+
result.addTypes(resultType);
395+
return success();
396+
}
397+
398+
void ScaledExtPacked816Op::print(OpAsmPrinter &p) {
399+
p << " ";
400+
p.printOptionalAttrDict(
401+
(*this)->getAttrs(),
402+
/*elideAttrs=*/{"blockSize", "firstScaleLane", "firstScaleByte"});
403+
p << " " << getSource();
404+
p << " scale(" << getScale() << ")";
405+
p << " blockSize(" << getBlockSize() << ")";
406+
p << " firstScaleLane(" << getFirstScaleLane() << ")";
407+
p << " firstScaleByte(" << getFirstScaleByte() << ")";
408+
p << " : " << getSource().getType() << " to " << getRes().getType();
409+
}
410+
341411
//===----------------------------------------------------------------------===//
342412
// WMMAOp
343413
//===----------------------------------------------------------------------===//

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,61 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) ->
221221
func.return %ret : vector<2xbf16>
222222
}
223223

224+
// CHECK-LABEL: func.func @scaled_ext_packed816_fp4
225+
func.func @scaled_ext_packed816_fp4(%v: vector<8xf4E2M1FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
226+
// CHECK: amdgpu.scaled_ext_packed816
227+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN> to vector<8xf16>
228+
// CHECK: amdgpu.scaled_ext_packed816
229+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN> to vector<8xbf16>
230+
// CHECK: amdgpu.scaled_ext_packed816
231+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf4E2M1FN> to vector<8xf32>
232+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
233+
}
234+
235+
// CHECK-LABEL: func.func @scaled_ext_packed816_fp8
236+
func.func @scaled_ext_packed816_fp8(%v: vector<8xf8E4M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
237+
// CHECK: amdgpu.scaled_ext_packed816
238+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN> to vector<8xf16>
239+
// CHECK: amdgpu.scaled_ext_packed816
240+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN> to vector<8xbf16>
241+
// CHECK: amdgpu.scaled_ext_packed816
242+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E4M3FN> to vector<8xf32>
243+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
244+
}
245+
246+
// CHECK-LABEL: func.func @scaled_ext_packed816_bf8
247+
func.func @scaled_ext_packed816_bf8(%v: vector<8xf8E5M2>, %scale: vector<4xf8E8M0FNU>) -> (vector<8xf16>, vector<8xbf16>, vector<8xf32>) {
248+
// CHECK: amdgpu.scaled_ext_packed816
249+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2> to vector<8xf16>
250+
// CHECK: amdgpu.scaled_ext_packed816
251+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2> to vector<8xbf16>
252+
// CHECK: amdgpu.scaled_ext_packed816
253+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2> to vector<8xf32>
254+
func.return %ret0, %ret1, %ret2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
255+
}
256+
257+
// CHECK-LABEL: func.func @scaled_ext_packed816_fp6
258+
func.func @scaled_ext_packed816_fp6(%v: vector<16xf6E2M3FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
259+
// CHECK: amdgpu.scaled_ext_packed816
260+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN> to vector<16xf16>
261+
// CHECK: amdgpu.scaled_ext_packed816
262+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN> to vector<16xbf16>
263+
// CHECK: amdgpu.scaled_ext_packed816
264+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E2M3FN> to vector<16xf32>
265+
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
266+
}
267+
268+
// CHECK-LABEL: func.func @scaled_ext_packed816_bf16
269+
func.func @scaled_ext_packed816_bf16(%v: vector<16xf6E3M2FN>, %scale: vector<4xf8E8M0FNU>) -> (vector<16xf16>, vector<16xbf16>, vector<16xf32>) {
270+
// CHECK: amdgpu.scaled_ext_packed816
271+
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN> to vector<16xf16>
272+
// CHECK: amdgpu.scaled_ext_packed816
273+
%ret1 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN> to vector<16xbf16>
274+
// CHECK: amdgpu.scaled_ext_packed816
275+
%ret2 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(32) firstScaleLane(0) firstScaleByte(0) : vector<16xf6E3M2FN> to vector<16xf32>
276+
func.return %ret0, %ret1, %ret2 : vector<16xf16>, vector<16xbf16>, vector<16xf32>
277+
}
278+
224279
// CHECK-LABEL: func.func @packed_scaled_trunc_f8e4m3_f32
225280
// CHECK: amdgpu.packed_scaled_trunc
226281
func.func @packed_scaled_trunc_f8e4m3_f32(%v: vector<2xf32>, %scale: f32) -> vector<4xf8E4M3FN> {

0 commit comments

Comments
 (0)