 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -32,6 +35,7 @@ using namespace mlir::amdgpu;
 namespace {
 // Define commonly used chipset versions for convenience.
 constexpr Chipset kGfx942 = Chipset(9, 4, 2);
+constexpr Chipset kGfx950 = Chipset(9, 5, 0);
 
 struct ArithToAMDGPUConversionPass final
     : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +77,28 @@ struct TruncfToFloat16RewritePattern final
                                  PatternRewriter &rewriter) const override;
 };
 
+struct ScalingExtFRewritePattern final
+    : OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingExtFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+struct ScalingTruncFRewritePattern final
+    : OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  ScalingTruncFRewritePattern(MLIRContext *ctx)
+      : OpRewritePattern::OpRewritePattern(ctx) {}
+
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // end namespace
 
 static bool isSupportedF8(Type elementType, Chipset chipset) {
@@ -395,6 +421,247 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   return success();
 }
 
+/// Get the broadcasted / splatted value for a chain of ops.
+static Value getOriginalVectorValue(Value value) {
+  Value current = value;
+  while (Operation *definingOp = current.getDefiningOp()) {
+    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
+                      .Case<vector::ShapeCastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return true;
+                      })
+                      .Case<vector::BroadcastOp>([&current](auto op) {
+                        current = op.getSource();
+                        return false;
+                      })
+                      .Case<vector::SplatOp>([&current](auto op) {
+                        current = op.getInput();
+                        return false;
+                      })
+                      .Default([](Operation *) { return false; });
+
+    if (!skipOp) {
+      break;
+    }
+  }
+  return current;
+}
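+// Illustrative example (not part of the patch): given the chain
+//   %b = vector.broadcast %s : vector<4xf8E8M0FNU> to vector<2x4xf8E8M0FNU>
+//   %c = vector.shape_cast %b : vector<2x4xf8E8M0FNU> to vector<8xf8E8M0FNU>
+// getOriginalVectorValue(%c) looks through the shape_cast, then stops at the
+// broadcast and returns its source %s.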
+
+LogicalResult
+ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
+                                           PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
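+  // The packed scaled-conversion ops consume the scale as f32, so narrower
+  // scale types (e.g. f8E8M0FNU or f16) are extended and wider ones (f64)
+  // truncated before use.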
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+
+  if (!outVecType) {
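+    // Scalar case: wrap the operand in a 1-element vector so the packed op
+    // can be used, then extract lane 0 of the packed result.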
+    Value inCast =
+        rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
+    // TODO: replace this with non-packed ScaledExtOp
+    Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+        loc, extScaleResultType, inCast, scale, 0);
+    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
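+  // Illustrative numbers: inShape = [8, 8] with a pre-broadcast scale of
+  // shape [4, 8] gives ratio = [2, 1] and blockSize = 2, i.e. each scale
+  // element covers a 2x1 block of the input.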
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
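+  // Tile the input by `ratio`: each tile shares one scale value. Flatten the
+  // tile to 1-D, convert it in packed chunks of at most opWidth elements,
+  // then shape the converted chunk back and insert it into the result.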
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+      Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
+          loc, extScaleResultType, slice, uniformScale, 0);
+      if (sliceWidth != opWidth)
+        scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleExt, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleExt, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
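+// Illustrative end-to-end example (hypothetical IR, assumed assembly syntax):
+//   %s = vector.splat %scale : vector<4xf8E8M0FNU>
+//   %r = arith.scaling_extf %in, %s
+//       : vector<4xf8E4M3FN>, vector<4xf8E8M0FNU> to vector<4xf32>
+// Here the pre-broadcast scale is a scalar, so blockSize = 4 and the pattern
+// emits two amdgpu.scaled_ext_packed ops (two elements each) whose results
+// are stitched back together into a vector<4xf32>.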
+
+LogicalResult
+ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
+                                             PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  constexpr int64_t opWidth = 2;
+
+  Value in = op.getIn();
+  Value scale = op.getScale();
+  Value out = op.getOut();
+
+  Type f32 = rewriter.getF32Type();
+  Type inType = getElementTypeOrSelf(in);
+  Type scaleType = getElementTypeOrSelf(scale);
+  Type outType = getElementTypeOrSelf(out);
+
+  VectorType outVecType = dyn_cast<VectorType>(out.getType());
+  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
+
+  if (outVecType && outVecType.isScalable())
+    return failure();
+
+  Type scaleF32Type =
+      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
+  if (scaleType.getIntOrFloatBitWidth() < 32)
+    scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
+  else if (scaleType.getIntOrFloatBitWidth() > 32)
+    scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);
+
+  Value zero = rewriter.create<arith::ConstantOp>(
+      loc, outType, rewriter.getFloatAttr(outType, 0.0));
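+  // The packed trunc op fills a 32-bit result, so it packs
+  // 32 / bitwidth(outType) lanes per op, e.g. four lanes for an 8-bit float
+  // output.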
+  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+
+  if (!outVecType) {
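+    // Scalar case: same single-element-vector wrapping as in
+    // ScalingExtFRewritePattern above.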
+    Type inVecType = VectorType::get(1, inType);
+    Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
+    // TODO: replace this with non-packed ScaledTruncOp
+    Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+        loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
+    scaleTrunc =
+        rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
+    return success();
+  }
+
+  VectorType inVecType = cast<VectorType>(in.getType());
+  Value origScale = getOriginalVectorValue(op.getScale());
+
+  ArrayRef<int64_t> inShape = inVecType.getShape();
+  SmallVector<int64_t> originalScaleShape;
+  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+
+  originalScaleShape.insert(originalScaleShape.end(),
+                            inShape.size() - originalScaleShape.size(), 1);
+
+  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  assert(maybeRatio &&
+         "failed to derive block size from broadcast or splat operation");
+
+  SmallVector<int64_t> ratio =
+      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
+
+  int64_t blockSize = computeProduct(ratio);
+
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
+
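+  // Same tiling as in ScalingExtFRewritePattern, except that each packed
+  // trunc yields numPackedElem lanes and only the first sliceWidth of them
+  // are kept.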
+  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
+    SmallVector<int64_t> strides(offsets.size(), 1);
+    Value block = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, in, offsets, ratio, strides);
+    VectorType block1DType = VectorType::get(blockSize, inType);
+    Value block1D =
+        rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
+    Value uniformScale =
+        rewriter.create<vector::ExtractOp>(loc, scale, offsets);
+
+    VectorType blockResultType = VectorType::get(blockSize, outType);
+    Value blockResult =
+        rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);
+
+    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+         i < blockSize;
+         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
+      Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, block1D, i, sliceWidth, 1);
+      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+      Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
+          loc, truncScaleResultType, slice, uniformScale, 0,
+          /*existing=*/nullptr);
+      int64_t packedWidth =
+          cast<VectorType>(scaleTrunc.getType()).getNumElements();
+      if (packedWidth != opWidth)
+        scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, scaleTrunc, 0, sliceWidth, 1);
+      blockResult = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, scaleTrunc, blockResult, i, 1);
+    }
+
+    VectorType resultType = VectorType::get(ratio, outType);
+    Value cast =
+        rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
+    result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
+                                                           offsets, strides);
+  }
+
+  rewriter.replaceOp(op, result);
+
+  return success();
+}
+
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
     bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +673,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   }
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());
+
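+  // The scaled conversions lower to packed scaled-conversion instructions
+  // introduced on gfx950, so gate the patterns on that chipset.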
+  if (chipset >= kGfx950) {
+    patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
+    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
+  }
 }
 
 void ArithToAMDGPUConversionPass::runOnOperation() {