Skip to content

Commit fa1136f

Browse files
authored
Add Fp4 lowering from MIGraphX Dialect (#2089)
Add lowering for Fp4 scaled GEMM in migraphx dialect
1 parent b37df0d commit fa1136f

38 files changed

+1515
-228
lines changed

external/llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1254,7 +1254,8 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
12541254
bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
12551255
if (isa<FloatType>(type)) {
12561256
return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
1257-
Float8E5M2Type, Float8E4M3FNUZType, Float8E5M2FNUZType>(type);
1257+
Float8E5M2Type, Float8E4M3FNUZType, Float8E5M2FNUZType,
1258+
Float4E2M1FNType, Float8E8M0FNUType>(type);
12581259
}
12591260
if (auto intTy = dyn_cast<IntegerType>(type)) {
12601261
if (intTy.isSignless()) {

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 32 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -523,31 +523,47 @@ def MIGraphX_MultiBroadcastOp :
523523
}];
524524
}
525525

526-
class MIGraphX_DotOpBase<string mnemonic, list<Type> inputTypes=[], list<Type> outputTypes=[]> :
527-
MIGraphX_Op<mnemonic>,
528-
Arguments<(ins MIXRShapedOf<inputTypes>:$in_a,
529-
MIXRShapedOf<inputTypes>:$in_b
530-
)>,
531-
Results<(outs MIXRShapedOf<outputTypes>:$output)> {
532-
let assemblyFormat = [{
533-
$in_a `,` $in_b attr-dict `:` type($in_a) `,` type($in_b) `->` type($output)
534-
}];
535-
}
536-
537-
def MIGraphX_QuantDotOp :
538-
MIGraphX_DotOpBase<"quant_dot", [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2, I8, SI8], [F32, I32, SI32]>{
526+
defvar QuantDotInTypes = [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2, F4E2M1FN,
527+
I8, SI8];
528+
defvar QuantDotOutTypes = [F32, I32, SI32];
529+
defvar QuantDotScaleTypes = [F8E8M0FNU, F32];
530+
531+
def MIGraphX_QuantDotOp
532+
: MIGraphX_Op<"quant_dot", [AttrSizedOperandSegments]>,
533+
Arguments<(ins MIXRShapedOf<QuantDotInTypes>:$in_a,
534+
MIXRShapedOf<QuantDotInTypes>:$in_b,
535+
Optional<MIXRShapedOf<QuantDotScaleTypes>>:$scaleA,
536+
Optional<MIXRShapedOf<QuantDotScaleTypes>>:$scaleB)>,
537+
Results<(outs MIXRShapedOf<QuantDotOutTypes>:$output)> {
539538
let summary = "Dot product of quantized tensors";
540539
let description = [{
541-
The `migraphx.quant_dot` op computes the dot product of two tensors.
540+
The `migraphx.quant_dot` op computes the dot product of two tensors. This operation is used when converting MIGraphX IR's quant_dot operation to the MIGraphX dialect.
541+
Usually, `migraphx.quant_dot` operations are accompanied by QuantizeLinear and DequantizeLinear operations that convert data types between higher precision and lower precision.
542+
The `migraphx.quant_dot` operation is also used for "scaled" GEMMs.
542543
}];
544+
let assemblyFormat = [{
545+
$in_a (`scaled` `by` $scaleA^)? `,` $in_b (`scaled` `by` $scaleB^)? attr-dict
546+
`:` type($in_a) (`scaled` `by` type($scaleA)^)? `,` type($in_b)
547+
(`scaled` `by` type($scaleB)^)?
548+
`->` type($output)
549+
}];
550+
let hasVerifier = 1;
543551
}
544552

545-
def MIGraphX_DotOp :
546-
MIGraphX_DotOpBase<"dot", [F32, F16, BF16], [F32, F16, BF16]>{
553+
defvar DotInTypes = [F32, F16, BF16, F4E2M1FN];
554+
defvar DotOutTypes = [F32, F16, BF16];
555+
556+
def MIGraphX_DotOp : MIGraphX_Op<"dot">,
557+
Arguments<(ins MIXRShapedOf<DotInTypes>:$in_a,
558+
MIXRShapedOf<DotInTypes>:$in_b)>,
559+
Results<(outs MIXRShapedOf<DotOutTypes>:$output)> {
547560
let summary = "Dot product of tensors";
548561
let description = [{
549562
The `migraphx.dot` op computes the dot product of two tensors.
550563
}];
564+
let assemblyFormat = [{
565+
$in_a `,` $in_b attr-dict `:` type($in_a) `,` type($in_b) `->` type($output)
566+
}];
551567
}
552568

553569
def MIGraphX_SoftmaxOp :

0 commit comments

Comments
 (0)