diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 599b3b982ec7f..adc27ae6bdafb 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + Arguments<(ins FloatLike:$in, FloatLike:$scale, + OptionalAttr:$fastmath)>, + Results<(outs FloatLike:$out)> { + let summary = "Upcasts input floats using provided scales values following " + "OCP MXFP Spec"; + let description = [{ + This operation upcasts input floating-point values using provided scale + values. It expects both scales and the input operand to be of the same shape, + making the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. + + If scales are calculated per block where blockSize != 1, then scales may + require broadcasting to make this operation elementwise. For example, let's + say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_extf`, scales must be + broadcasted to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. 
Internally, + `arith.scaling_extf` would perform the following: + + ``` + resultTy = get_type(result) + scaleTy = get_type(scale) + inputTy = get_type(input) + scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 + scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy + input.extf = arith.extf(input) : inputTy to resultTy + result = arith.mulf(scale.extf, input.extf) + ``` + It propagates NaN values. Therefore, if either scale or the input element + contains NaN, then the output element value will also be a NaN. + }]; + let hasVerifier = 1; + let assemblyFormat = + [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:` + type($in) `,` type($scale) `to` type($out)}]; +} + //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -1280,6 +1332,63 @@ def Arith_TruncFOp : attr-dict `:` type($in) `to` type($out) }]; } +//===----------------------------------------------------------------------===// +// Scaling TruncFOp +//===----------------------------------------------------------------------===// + +def Arith_ScalingTruncFOp + : Arith_Op<"scaling_truncf", + [Pure, SameInputOutputTensorDims, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, + Arguments<(ins FloatLike:$in, FloatLike:$scale, + OptionalAttr:$roundingmode, + OptionalAttr:$fastmath)>, + Results<(outs FloatLike:$out)> { + let summary = "Downcasts input floating point values using provided scales " + "values following OCP MXFP Spec"; + let description = [{ + This operation downcasts input using the provided scale values. It expects + both scales and the input operand to be of the same shape and, therefore, + makes the operation elementwise. Scales are usually calculated per block + following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537. 
+ Users are required to normalize and clamp the scales as necessary before + passing them to this operation. OCP MXFP spec also does the flushing of denorms + on the input operand, which should be handled during lowering by passing the appropriate + fastMath flag to this operation. + + If scales are calculated per block where blockSize != 1, scales may require + broadcasting to make this operation elementwise. For example, let's say the + input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and + assuming quantization happens on the last axis, the input can be reshaped to + `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated + per block on the last axis. Therefore, scales will be of shape + `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other + shape as long as it is broadcast compatible with the input, e.g., + `<1 x 1 x ... (dimN/blockSize) x 1>`. + + In this example, before calling into `arith.scaling_truncf`, scales must be + broadcasted to `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Note + that there could be multiple quantization axes. Internally, + `arith.scaling_truncf` would perform the following: + + ``` + scaleTy = get_type(scale) + inputTy = get_type(input) + resultTy = get_type(result) + scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 + scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy + result = arith.divf(input, scale.extf) + result.cast = arith.truncf(result, resultTy) + ``` + }]; + let hasVerifier = 1; + let assemblyFormat = + [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? 
attr-dict `:` + type($in) `,` type($scale) `to` type($out)}]; +} + //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 5aaac8d8e3dc5..e0a4567d6f406 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns); /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts. void populateExpandF8E8M0Patterns(RewritePatternSet &patterns); +/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops +void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns); + /// Add patterns to expand Arith ops. void populateArithExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3f7b3268dd085..d68dbdb1efeef 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -60,6 +60,7 @@ class Builder { Attribute metadata = Attribute()); // Types. 
+ FloatType getF8E8M0Type(); FloatType getBF16Type(); FloatType getF16Type(); FloatType getTF32Type(); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 41f2d0f3425e2..9e53e195274aa 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { LogicalResult arith::ExtFOp::verify() { return verifyExtOp(*this); } +//===----------------------------------------------------------------------===// +// ScalingExtFOp +//===----------------------------------------------------------------------===// + +bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs, + TypeRange outputs) { + return checkWidthChangeCast(inputs.front(), outputs); +} + +LogicalResult arith::ScalingExtFOp::verify() { + return verifyExtOp(*this); +} + //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// @@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() { return verifyTruncateOp(*this); } +//===----------------------------------------------------------------------===// +// ScalingTruncFOp +//===----------------------------------------------------------------------===// + +bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs, + TypeRange outputs) { + return checkWidthChangeCast(inputs.front(), outputs); +} + +LogicalResult arith::ScalingTruncFOp::verify() { + return verifyTruncateOp(*this); +} + //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 95546bb09e765..534aff9562b7a 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ 
b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Arith/Transforms/Passes.h" - #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value, return rewriter.create( loc, DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create(loc, attr); } @@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern { f32Bits = b.create(isNan, cF32NaN, f32Bits); Value result = b.create(f32Ty, f32Bits); if (resultETy.getIntOrFloatBitWidth() < 32) { - result = b.create(resultTy, result); + result = b.create(resultTy, result, nullptr, + op.getFastmathAttr()); } else if (resultETy.getIntOrFloatBitWidth() > 32) { - result = b.create(resultTy, result); + result = b.create(resultTy, result, op.getFastmathAttr()); } rewriter.replaceOp(op, result); return success(); @@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); if (operandETy.getIntOrFloatBitWidth() < 32) { - operand = b.create(f32Ty, operand); + operand = b.create(f32Ty, operand, op.getFastmathAttr()); } else if (operandETy.getIntOrFloatBitWidth() > 32) { - operand = b.create(f32Ty, operand); + operand = b.create( + f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); } Value f32Bits = b.create(i32Ty, operand); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); @@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern { } }; +struct ScalingExtFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + 
LogicalResult matchAndRewrite(arith::ScalingExtFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value inputOperand = op.getIn(); + Value scaleOperand = op.getScale(); + Type scaleTy = scaleOperand.getType(); + Type scaleETy = getElementTypeOrSelf(scaleOperand); + // allow implicit exponent extraction from 16/32 bits floats + if (scaleETy.getIntOrFloatBitWidth() >= 16) { + scaleETy = b.getF8E8M0Type(); + scaleTy = cloneToShapedType(scaleTy, scaleETy); + scaleOperand = b.create(scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); + } + if (!llvm::isa(scaleETy)) { + return rewriter.notifyMatchFailure( + op, "scaling_extf is using scales of type which can not be converted " + "to f8E8M0FNU"); + } + Type resultTy = op.getType(); + // extf on scale will essentially create floating point number + // of type resulTy that is 2^scale and will also propagate NaNs + Value scaleExt = + b.create(resultTy, scaleOperand, op.getFastmathAttr()); + Value inputExt = + b.create(resultTy, inputOperand, op.getFastmathAttr()); + Value result = + b.create(inputExt, scaleExt, op.getFastmathAttr()); + rewriter.replaceOp(op, result); + return success(); + } +}; + +/* +Expands arith.ScalingTruncFOp(in, scale) into + scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU + result = arith.truncf(in / (2^scale)) + */ +struct ScalingTruncFOpConverter + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value inputOperand = op.getIn(); + Value scaleOperand = op.getScale(); + Type scaleTy = scaleOperand.getType(); + Type scaleETy = getElementTypeOrSelf(scaleOperand); + // allow implicit exponent extraction from 16/32 bits floats + if (scaleETy.getIntOrFloatBitWidth() >= 16) { + scaleETy = b.getF8E8M0Type(); + scaleTy = cloneToShapedType(scaleTy, scaleETy); + scaleOperand = 
b.create(scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); + } + if (!llvm::isa(scaleETy)) { + return rewriter.notifyMatchFailure( + op, "scaling_truncf is using scales type which can not be converted " + "to f8E8M0FNU"); + } + Type resultTy = op.getType(); + Type inputTy = inputOperand.getType(); + // this will create a floating point number of type + // inputTy that is 2^scale and will also propagate NaNs + scaleOperand = + b.create(inputTy, scaleOperand, op.getFastmathAttr()); + Value result = b.create(inputOperand, scaleOperand, + op.getFastmathAttr()); + Value resultCast = b.create( + resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); + rewriter.replaceOp(op, resultCast); + return success(); + } +}; + struct ArithExpandOpsPass : public arith::impl::ArithExpandOpsPassBase { using ArithExpandOpsPassBase::ArithExpandOpsPassBase; @@ -432,7 +510,9 @@ struct ArithExpandOpsPass arith::MaximumFOp, arith::MinimumFOp, arith::MaxNumFOp, - arith::MinNumFOp + arith::MinNumFOp, + arith::ScalingExtFOp, + arith::ScalingTruncFOp >(); if (includeBf16) { @@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { patterns.getContext()); } +void mlir::arith::populateExpandScalingExtTruncPatterns( + RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); + populateExpandScalingExtTruncPatterns(patterns); // clang-format off patterns.add< MaxMinIOpConverter, @@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { MaximumMinimumFOpConverter, MaximumMinimumFOpConverter, MaxNumMinNumFOpConverter, - MaxNumMinNumFOpConverter + MaxNumMinNumFOpConverter >(patterns.getContext()); // clang-format on } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 89102115cdc40..5f7bc50afc418 100644 --- a/mlir/lib/IR/Builders.cpp 
+++ b/mlir/lib/IR/Builders.cpp @@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef locs, Attribute metadata) { // Types. //===----------------------------------------------------------------------===// +FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); } + FloatType Builder::getBF16Type() { return BFloat16Type::get(context); } FloatType Builder::getF16Type() { return Float16Type::get(context); } diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index 5b6badf13d763..db1349feaff3a 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -verify-diagnostics -split-input-file | FileCheck %s +// RUN: mlir-opt %s -arith-expand -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=SCHECK // Test ceil divide with signed integer // CHECK-LABEL: func @ceildivi @@ -253,7 +254,7 @@ func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU { %0 = arith.truncf %arg0 : f32 to f8E8M0FNU return %0 : f8E8M0FNU } -// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU +// CHECK-LABEL: @truncf_f32_to_f8E8M0FNU // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32 // CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32 @@ -267,7 +268,7 @@ func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU { %0 = arith.truncf %arg0 : f16 to f8E8M0FNU return %0 : f8E8M0FNU } -// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU +// CHECK-LABEL: @truncf_f16_to_f8E8M0FNU // CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32 // CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32 // CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32 @@ -305,9 +306,76 @@ func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : 
vector<4xbf16>) -> vector<4xf // CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU // CHECK-NOT: arith.truncf +// CHECK: return +// ----- + +func.func @scaling_truncf_f32_to_f4E2M1FN(%arg0 : f32, %arg1: f8E8M0FNU) -> f4E2M1FN { + %0 = arith.scaling_truncf %arg0, %arg1 : f32, f8E8M0FNU to f4E2M1FN + return %0 : f4E2M1FN +} + +// SCHECK-LABEL: @scaling_truncf_f32_to_f4E2M1FN +// SCHECK: %[[SCALEF32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32 +// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF32]] : f32 +// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : f32 to f4E2M1FN +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_truncf_vector_f16_to_f6E3M2FN(%arg0 : vector<4xf16>, %arg1: vector<4xf8E8M0FNU>) -> vector<4xf6E3M2FN> { + %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf8E8M0FNU> to vector<4xf6E3M2FN> + return %0 : vector<4xf6E3M2FN> +} + +// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f6E3M2FN +// SCHECK: %[[SCALEF16:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16> +// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEF16]] : vector<4xf16> +// SCHECK: %[[RESULT:.+]] = arith.truncf %[[DIVF]] : vector<4xf16> to vector<4xf6E3M2FN> +// SCHECK: return %[[RESULT]] : vector<4xf6E3M2FN> // ----- + +func.func @scaling_truncf_propagate_rounding_mode_fast_math(%arg0 : vector<4xf16>, %arg1: vector<4xf16>) -> vector<4xf6E3M2FN> { + %0 = arith.scaling_truncf %arg0, %arg1 to_nearest_even fastmath : vector<4xf16>, vector<4xf16> to vector<4xf6E3M2FN> + return %0 : vector<4xf6E3M2FN> +} +// SCHECK-LABEL: @scaling_truncf_propagate_rounding_mode_fast_math +// SCHECK: %[[SCALEF8:.+]] = arith.truncf %arg1 fastmath : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: %[[SCALEINTY:.+]] = arith.extf %[[SCALEF8]] fastmath : vector<4xf8E8M0FNU> to vector<4xf16> +// SCHECK: %[[DIVF:.+]] = arith.divf %arg0, %[[SCALEINTY]] fastmath : vector<4xf16> +// SCHECK: %[[TRUNCF:.+]] = arith.truncf [[_:%[a-zA-Z0-9_]+]] to_nearest_even fastmath : vector<4xf16> to 
vector<4xf6E3M2FN> +// SCHECK: return %[[TRUNCF]] : vector<4xf6E3M2FN> + +// ----- + +func.func @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales(%arg0: f16, %arg1 : f16) -> f4E2M1FN { + %0 = arith.scaling_truncf %arg0, %arg1 : f16, f16 to f4E2M1FN + return %0 : f4E2M1FN +} +// SCHECK-LABEL: @scaling_truncf_f16_to_f4E2M1FN_using_f16_scales +// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : f16 to f8E8M0FN +// SCHECK: return + +// ----- +func.func @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales(%arg0: vector<4xf16>, %arg1 : vector<4xf16>) -> vector<4xf4E2M1FN> { + %0 = arith.scaling_truncf %arg0, %arg1 : vector<4xf16>, vector<4xf16> to vector<4xf4E2M1FN> + return %0 : vector<4xf4E2M1FN> +} +// SCHECK-LABEL: @scaling_truncf_vector_f16_to_f4E2M1FN_using_f16_scales +// SCHECK: %[[SCALETRUNCF:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: return + +// ----- + +func.func @invalid_scaling_truncf_to_f4E2M1FN(%arg0: f16, %arg1 : f8E5M2FNUZ) -> f4E2M1FN { + // expected-error@+1 {{failed to legalize operation 'arith.scaling_truncf' that was explicitly marked illegal}} + %0 = arith.scaling_truncf %arg0, %arg1 : f16, f8E5M2FNUZ to f4E2M1FN + return %0 : f4E2M1FN +} + +// ----- + func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 { %0 = arith.extf %arg0 : f8E8M0FNU to f32 return %0 : f32 @@ -332,7 +400,7 @@ func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 { return %0 : f16 } -// CHECK-LABLE: @extf_f8E8M0FNU_to_f16 +// CHECK-LABEL: @extf_f8E8M0FNU_to_f16 // CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8 // CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8 // CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32 @@ -374,7 +442,109 @@ func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector< // CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16 // CHECK-NOT: arith.extf +// CHECK: return + +// ----- + +func.func @scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E8M0FNU) -> f32 { + %0 = 
arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E8M0FNU to f32 + return %0 : f32 +} + +// SCHECK-LABEL: @scaling_extf_to_f32 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : f8E8M0FNU to f32 +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32 +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32 +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_to_f32_using_f16_scales(%arg0: f4E2M1FN, %arg1 : f16) -> f32 { + %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f16 to f32 + return %0 : f32 +} + +// SCHECK-LABEL: @scaling_extf_to_f32_using_f16_scales +// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : f16 to f8E8M0FNU +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : f8E8M0FNU to f32 +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : f4E2M1FN to f32 +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : f32 +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @invalid_scaling_extf_to_f32(%arg0: f4E2M1FN, %arg1 : f8E5M2FNUZ) -> f32 { + // expected-error@+1 {{failed to legalize operation 'arith.scaling_extf' that was explicitly marked illegal}} + %0 = arith.scaling_extf %arg0, %arg1 : f4E2M1FN, f8E5M2FNUZ to f32 + return %0 : f32 +} + +// ----- + +func.func @scaling_extf_vector_to_f32(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf32> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf32> + return %0 : vector<4xf32> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_f32 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf32> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_f16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xf16> { + %0 = arith.scaling_extf %arg0, 
%arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xf16> + return %0 : vector<4xf16> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_f16 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xf16> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf16> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf16> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_bf16(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf8E8M0FNU>) -> vector<4xbf16> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf8E8M0FNU> to vector<4xbf16> + return %0 : vector<4xbf16> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_bf16 +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %arg1 : vector<4xf8E8M0FNU> to vector<4xbf16> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xbf16> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xbf16> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_f32_using_f16_scales(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> { + %0 = arith.scaling_extf %arg0, %arg1 : vector<4xf4E2M1FN>, vector<4xf16> to vector<4xf32> + return %0 : vector<4xf32> +} + +// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales +// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] : vector<4xf8E8M0FNU> to vector<4xf32> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 : vector<4xf4E2M1FN> to vector<4xf32> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] : vector<4xf32> +// SCHECK: return %[[RESULT]] + +// ----- + +func.func @scaling_extf_vector_to_f32_using_f16_scales_fastmath(%arg0: vector<4xf4E2M1FN>, %arg1 : vector<4xf16>) -> vector<4xf32> { + %0 = arith.scaling_extf %arg0, %arg1 fastmath : vector<4xf4E2M1FN>, vector<4xf16> 
to vector<4xf32> + return %0 : vector<4xf32> +} +// SCHECK-LABEL: @scaling_extf_vector_to_f32_using_f16_scales_fastmath +// SCHECK: %[[TRUNCF_SCALE:.+]] = arith.truncf %arg1 fastmath : vector<4xf16> to vector<4xf8E8M0FNU> +// SCHECK: %[[EXT_SCALE:.+]] = arith.extf %[[TRUNCF_SCALE]] fastmath : vector<4xf8E8M0FNU> to vector<4xf32> +// SCHECK: %[[EXT_INPUT:.+]] = arith.extf %arg0 fastmath : vector<4xf4E2M1FN> to vector<4xf32> +// SCHECK: %[[RESULT:.+]] = arith.mulf %[[EXT_INPUT]], %[[EXT_SCALE]] fastmath : vector<4xf32> +// SCHECK: return %[[RESULT]] // -----