1- // ===- LegalizeForLLVMExport .cpp - Prepare ArmSME for LLVM translation ----===//
1+ // ===- ArmSMEToLLVM .cpp - Convert ArmSME to LLVM dialect -------------- ----===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
8+ //
9+ // This file implements lowering of ArmSME operations to LLVM intrinsics.
10+ //
11+ // ===----------------------------------------------------------------------===//
12+
13+ #include " mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
814
915#include " mlir/Conversion/LLVMCommon/ConversionTarget.h"
1016#include " mlir/Conversion/LLVMCommon/Pattern.h"
1117#include " mlir/Dialect/Arith/IR/Arith.h"
1218#include " mlir/Dialect/ArmSME/IR/ArmSME.h"
13- #include " mlir/Dialect/ArmSME/Transforms/Transforms.h"
1419#include " mlir/Dialect/ArmSME/Utils/Utils.h"
1520#include " mlir/Dialect/Func/IR/FuncOps.h"
1621#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
17- #include " mlir/Dialect/SCF/IR/SCF.h"
1822#include " mlir/Dialect/Vector/IR/VectorOps.h"
23+ #include " mlir/Pass/Pass.h"
24+ #include " mlir/Transforms/DialectConversion.h"
25+
26+ namespace mlir {
27+ #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
28+ #include " mlir/Conversion/Passes.h.inc"
29+ } // namespace mlir
1930
2031using namespace mlir ;
21- using namespace mlir ::arm_sme;
2232
2333namespace {
2434
@@ -40,11 +50,11 @@ namespace {
4050// / The 'arm_sme.cast_tile_to_vector' (which models the return) and the
4151// / 'arith.shli' (which generates the mask) will be folded away after tile
4252// / allocation and canonicalization.
43- struct ZeroOpConversion : public ConvertOpToLLVMPattern <ZeroOp> {
44- using ConvertOpToLLVMPattern<ZeroOp>::ConvertOpToLLVMPattern;
53+ struct ZeroOpConversion : public ConvertOpToLLVMPattern <arm_sme:: ZeroOp> {
54+ using ConvertOpToLLVMPattern<arm_sme:: ZeroOp>::ConvertOpToLLVMPattern;
4555
4656 LogicalResult
47- matchAndRewrite (ZeroOp zero, OpAdaptor adaptor,
57+ matchAndRewrite (arm_sme:: ZeroOp zero, OpAdaptor adaptor,
4858 ConversionPatternRewriter &rewriter) const override {
4959 auto loc = zero.getLoc ();
5060
@@ -121,7 +131,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<ZeroOp> {
121131};
122132
123133// / Lower `arm_sme.load_tile_slice` to SME intrinsics.
124- struct LoadTileSliceToArmSMELowering
134+ struct LoadTileSliceConversion
125135 : public ConvertOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
126136 using ConvertOpToLLVMPattern<
127137 arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern;
@@ -220,7 +230,7 @@ struct LoadTileSliceToArmSMELowering
220230};
221231
222232// / Lower `arm_sme.store_tile_slice` to SME intrinsics.
223- struct StoreTileSliceToArmSMELowering
233+ struct StoreTileSliceConversion
224234 : public ConvertOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
225235 using ConvertOpToLLVMPattern<
226236 arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern;
@@ -313,7 +323,7 @@ struct StoreTileSliceToArmSMELowering
313323};
314324
315325// / Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
316- struct MoveVectorToTileSliceToArmSMELowering
326+ struct MoveVectorToTileSliceConversion
317327 : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
318328 using ConvertOpToLLVMPattern<
319329 arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern;
@@ -373,7 +383,7 @@ struct MoveVectorToTileSliceToArmSMELowering
373383};
374384
375385// / Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
376- struct MoveTileSliceToVectorArmSMELowering
386+ struct MoveTileSliceToVectorConversion
377387 : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
378388 using ConvertOpToLLVMPattern<
379389 arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern;
@@ -456,7 +466,8 @@ struct OuterProductOpConversion
456466 // * half-precision - +sme2p1,+b16b16
457467 //
458468 // It should be possible to control lowering based on target features.
459- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
469+ // [1]
470+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
460471 if ((vectorType.getRank () != 2 ) || !vectorType.allDimsScalable ())
461472 return false ;
462473
@@ -475,7 +486,7 @@ struct OuterProductOpConversion
475486 };
476487
477488 // TODO: Support CombiningKind::Sub for outer products.
478- if (outerProductOp.getKind () != CombiningKind::Add)
489+ if (outerProductOp.getKind () != arm_sme:: CombiningKind::Add)
479490 return outerProductOp.emitError (" unsupported kind" );
480491
481492 auto resultVectorType = outerProductOp.getResultType ();
@@ -522,32 +533,56 @@ struct OuterProductOpConversion
522533
523534} // namespace
524535
525- void mlir::configureArmSMELegalizeForExportTarget (
526- LLVMConversionTarget &target) {
536+ namespace {
537+
// Pass implementation derived from the generated ConvertArmSMEToLLVMBase.
// It applies the ArmSME conversion patterns to the operation it runs on.
538+ struct ConvertArmSMEToLLVMPass
539+ : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
540+ void runOnOperation () override {
// Build the conversion target, the pattern set and the ArmSME type
// converter, all bound to this pass's context.
541+ LLVMConversionTarget target (getContext ());
542+ RewritePatternSet patterns (&getContext ());
543+ ArmSMETypeConverter converter (&getContext (),
544+ LowerToLLVMOptions (&getContext ()));
545+
// Declare which ops are legal/illegal, then register the rewrite patterns.
546+ configureArmSMEToLLVMConversionLegality (target);
547+ populateArmSMEToLLVMConversionPatterns (converter, patterns);
548+
// Run a partial conversion; report failure to the pass manager if it fails.
549+ if (failed (applyPartialConversion (getOperation (), target,
550+ std::move (patterns))))
551+ signalPassFailure ();
552+ }
553+ };
554+
555+ } // namespace
556+
// Configures `target` with the legality rules used by the ArmSME conversion:
// the ArmSME dialect is illegal except for the ops explicitly listed below.
557+ void mlir::configureArmSMEToLLVMConversionLegality (ConversionTarget &target) {
// Mark the whole ArmSME dialect illegal first; exceptions are added next.
558+ target.addIllegalDialect <arm_sme::ArmSMEDialect>();
// Ops that remain legal: GetTileID, the CastTileToVector/CastVectorToTile
// casts, and the arm_sme::aarch64_sme_* ops.
527559 target.addLegalOp <
528- scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector,
529- arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero,
530- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
531- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
532- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
533- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
534- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
535- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
536- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
537- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
538- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
539- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
540- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
541- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
542- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
543- target.addLegalOp <GetTileID>();
544- target.addIllegalOp <vector::OuterProductOp>();
560+ arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
561+ arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
562+ arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
563+ arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
564+ arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
565+ arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
566+ arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
567+ arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
568+ arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
569+ arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
570+ arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
571+ arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
572+ arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
573+ arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
574+ arm_sme::aarch64_sme_mopa>();
// The arith dialect and unrealized conversion casts are also left legal.
575+ target.addLegalDialect <arith::ArithDialect>();
576+ target.addLegalOp <UnrealizedConversionCastOp>();
577+ }
578+
// Adds the six ArmSME conversion patterns defined in this file to `patterns`,
// each constructed with `converter`.
579+ void mlir::populateArmSMEToLLVMConversionPatterns (
580+ ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
581+ patterns.add <LoadTileSliceConversion, MoveTileSliceToVectorConversion,
582+ MoveVectorToTileSliceConversion, StoreTileSliceConversion,
583+ OuterProductOpConversion, ZeroOpConversion>(converter);
545584}
546585
547- void mlir::populateArmSMELegalizeForLLVMExportPatterns (
548- LLVMTypeConverter &converter, RewritePatternSet &patterns) {
549- patterns.add <
550- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
551- MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
552- OuterProductOpConversion, ZeroOpConversion>(converter);
// Factory: returns a newly allocated ConvertArmSMEToLLVMPass owned by the
// caller through std::unique_ptr<Pass>.
586+ std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass () {
587+ return std::make_unique<ConvertArmSMEToLLVMPass>();
553588}
0 commit comments