-
Couldn't load subscription status.
- Fork 15k
[MLIR][WASM] Introduce the RaiseWasmMLIRPass to lower WasmSSA MLIR to core dialects #164562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[MLIR][WASM] Introduce the RaiseWasmMLIRPass to lower WasmSSA MLIR to core dialects #164562
Conversation
…R to lower level MLIR dialects Introduce the structure of the pass along with conversion of wasm.func, wasm.call and local.get lowerings for simple functions handling. -- Co-authored-by: Luc Forget <[email protected]> Co-authored-by: Jessica Paquette <[email protected]>
|
@llvm/pr-subscribers-mlir Author: Ferdinand Lemaire (flemairen6) ChangesThis is following #154674 and still related to https://discourse.llvm.org/t/rfc-mlir-dialect-for-webassembly/86758. This PR introduces the RaiseWasmMLIRPass. This pass lowers WasmSSA MLIR to other dialects of the LLVM ecosystem (namely: arith, math, cf and memref). Patch is 95.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164562.diff 38 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 40d866ec7bf10..664bcb00ab45a 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -58,6 +58,7 @@
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
+#include "mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 70e3e45c225db..860e474a067a4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1582,6 +1582,19 @@ def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
];
}
+//===----------------------------------------------------------------------===//
+// RaiseWasmMLIR
+//===----------------------------------------------------------------------===//
+
+def RaiseWasmMLIR : Pass<"raise-wasm-mlir"> {
+ let summary = "Convert Wasm dialect to a group of dialect as a bridge to LLVM MLIR conversion";
+ let dependentDialects = [
+ "func::FuncDialect", "arith::ArithDialect", "cf::ControlFlowDialect",
+ "memref::MemRefDialect", "vector::VectorDialect", "wasmssa::WasmSSADialect",
+ "math::MathDialect"
+ ];
+}
+
//===----------------------------------------------------------------------===//
// XeVMToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h b/mlir/include/mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h
new file mode 100644
index 0000000000000..a54fc45b5d048
--- /dev/null
+++ b/mlir/include/mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h
@@ -0,0 +1,30 @@
+//===- RaiseWasmMLIR.h - Convert wasm to standard dialects ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_RAISEWASM_RAISEWASMMLIR_H
+#define MLIR_CONVERSION_RAISEWASM_RAISEWASMMLIR_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class Pass;
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_RAISEWASMMLIR
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Collect a set of patterns to convert from the Wasm dialect to standard dialects.
+void populateRaiseWasmMLIRConversionPatterns(TypeConverter&, RewritePatternSet &);
+
+/// Create a pass to convert ops from WasmDialect to standard dialects.
+std::unique_ptr<Pass> createRaiseWasmMLIRPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_RAISEWASM_RAISEWASMMLIR_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index bebf1b8fff3f9..c43b5f3ad5489 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -52,6 +52,7 @@ add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(PtrToLLVM)
+add_subdirectory(RaiseWasm)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
diff --git a/mlir/lib/Conversion/RaiseWasm/CMakeLists.txt b/mlir/lib/Conversion/RaiseWasm/CMakeLists.txt
new file mode 100644
index 0000000000000..43b5fd79e49df
--- /dev/null
+++ b/mlir/lib/Conversion/RaiseWasm/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRWasmRaise
+ RaiseWasmMLIR.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/RaiseWasm
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIRControlFlowDialect
+ MLIRFuncDialect
+ MLIRMathDialect
+ MLIRMemRefDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ MLIRWasmSSADialect
+ )
diff --git a/mlir/lib/Conversion/RaiseWasm/RaiseWasmMLIR.cpp b/mlir/lib/Conversion/RaiseWasm/RaiseWasmMLIR.cpp
new file mode 100644
index 0000000000000..d67572f3e67ba
--- /dev/null
+++ b/mlir/lib/Conversion/RaiseWasm/RaiseWasmMLIR.cpp
@@ -0,0 +1,468 @@
+//===- RaiseWasmMLIR.cpp - Convert Wasm to less abstract dialects ---*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering of wasm operations to standard dialects ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/RaiseWasm/RaiseWasmMLIR.h"
+
+
+#include "llvm/Support/LogicalResult.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+#include <optional>
+
+#define DEBUG_TYPE "wasm-convert"
+
+namespace mlir {
+#define GEN_PASS_DEF_RAISEWASMMLIR
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::wasmssa;
+
+namespace {
+
+template <typename SourceOp, typename TargetIntOp, typename TargetFPOp>
+struct IntFPDispatchMappingConversion : OpConversionPattern<SourceOp> {
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Type type = srcOp.getRhs().getType();
+ if (type.isInteger()) {
+ rewriter.replaceOpWithNewOp<TargetIntOp>(srcOp, srcOp->getResultTypes(),
+ adaptor.getOperands());
+ return success();
+ }
+ if (!type.isFloat())
+ return failure();
+ rewriter.replaceOpWithNewOp<TargetFPOp>(srcOp, srcOp->getResultTypes(),
+ adaptor.getOperands());
+ return success();
+ }
+};
+
+using WasmAddOpConversion =
+ IntFPDispatchMappingConversion<AddOp, arith::AddIOp, arith::AddFOp>;
+using WasmMulOpConversion =
+ IntFPDispatchMappingConversion<MulOp, arith::MulIOp, arith::MulFOp>;
+using WasmSubOpConversion =
+ IntFPDispatchMappingConversion<SubOp, arith::SubIOp, arith::SubFOp>;
+
+/// Convert a k-ary source operation \p SourceOp into an operation \p TargetOp.
+/// Both \p SourceOp and \p TargetOp must have the same number of operands.
+template <typename SourceOp, typename TargetOp>
+struct OpMappingConversion : OpConversionPattern<SourceOp> {
+ using OpConversionPattern<SourceOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(SourceOp srcOp, typename SourceOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<TargetOp>(srcOp, srcOp->getResultTypes(),
+ adaptor.getOperands());
+ return success();
+ }
+};
+
+using WasmAndOpConversion = OpMappingConversion<AndOp, arith::AndIOp>;
+using WasmCeilOpConversion = OpMappingConversion<CeilOp, math::CeilOp>;
+/// TODO: SIToFP and UIToFP don't allow specification of the floating point
+/// rounding mode
+using WasmConvertSOpConversion =
+ OpMappingConversion<ConvertSOp, arith::SIToFPOp>;
+using WasmConvertUOpConversion =
+ OpMappingConversion<ConvertUOp, arith::UIToFPOp>;
+using WasmDemoteOpConversion = OpMappingConversion<DemoteOp, arith::TruncFOp>;
+using WasmDivFPOpConversion = OpMappingConversion<DivOp, arith::DivFOp>;
+using WasmDivSIOpConversion = OpMappingConversion<DivSIOp, arith::DivSIOp>;
+using WasmDivUIOpConversion = OpMappingConversion<DivUIOp, arith::DivUIOp>;
+using WasmExtendSOpConversion =
+ OpMappingConversion<ExtendSI32Op, arith::ExtSIOp>;
+using WasmExtendUOpConversion =
+ OpMappingConversion<ExtendUI32Op, arith::ExtUIOp>;
+using WasmFloorOpConversion = OpMappingConversion<FloorOp, math::FloorOp>;
+using WasmMaxOpConversion = OpMappingConversion<MaxOp, arith::MaximumFOp>;
+using WasmMinOpConversion = OpMappingConversion<MinOp, arith::MinimumFOp>;
+using WasmOrOpConversion = OpMappingConversion<OrOp, arith::OrIOp>;
+using WasmPromoteOpConversion = OpMappingConversion<PromoteOp, arith::ExtFOp>;
+using WasmRemSIOpConversion = OpMappingConversion<RemSIOp, arith::RemSIOp>;
+using WasmRemUIOpConversion = OpMappingConversion<RemUIOp, arith::RemUIOp>;
+using WasmReinterpretOpConversion =
+ OpMappingConversion<ReinterpretOp, arith::BitcastOp>;
+using WasmShLOpConversion = OpMappingConversion<ShLOp, arith::ShLIOp>;
+using WasmShRSOpConversion = OpMappingConversion<ShRSOp, arith::ShRSIOp>;
+using WasmShRUOpConversion = OpMappingConversion<ShRUOp, arith::ShRUIOp>;
+using WasmXOrOpConversion = OpMappingConversion<XOrOp, arith::XOrIOp>;
+using WasmNegOpConversion = OpMappingConversion<NegOp, arith::NegFOp>;
+using WasmCopySignOpConversion =
+ OpMappingConversion<CopySignOp, math::CopySignOp>;
+using WasmClzOpConversion =
+ OpMappingConversion<ClzOp, math::CountLeadingZerosOp>;
+using WasmCtzOpConversion =
+ OpMappingConversion<CtzOp, math::CountTrailingZerosOp>;
+using WasmPopCntOpConversion = OpMappingConversion<PopCntOp, math::CtPopOp>;
+using WasmAbsOpConversion = OpMappingConversion<AbsOp, math::AbsFOp>;
+using WasmTruncOpConversion = OpMappingConversion<TruncOp, math::TruncOp>;
+using WasmSqrtOpConversion = OpMappingConversion<SqrtOp, math::SqrtOp>;
+using WasmWrapOpConversion = OpMappingConversion<WrapOp, arith::TruncIOp>;
+
+struct WasmCallOpConversion : OpConversionPattern<FuncCallOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(FuncCallOp funcCallOp, FuncCallOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<func::CallOp>(
+ funcCallOp, funcCallOp.getCallee(), funcCallOp.getResults().getTypes(),
+ funcCallOp.getOperands());
+ return success();
+ }
+};
+
+struct WasmConstOpConversion : OpConversionPattern<ConstOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ConstOp constOp, ConstOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(constOp, constOp.getValue());
+ return success();
+ }
+};
+
+struct WasmFuncImportOpConversion : OpConversionPattern<FuncImportOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(FuncImportOp funcImportOp, FuncImportOp::Adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto nFunc = rewriter.replaceOpWithNewOp<func::FuncOp>(
+ funcImportOp, funcImportOp.getSymName(), funcImportOp.getType());
+ nFunc.setVisibility(SymbolTable::Visibility::Private);
+ return success();
+ }
+};
+
+struct WasmFuncOpConversion : OpConversionPattern<FuncOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(FuncOp funcOp, FuncOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto newFunc = rewriter.create<func::FuncOp>(
+ funcOp->getLoc(), funcOp.getSymName(), funcOp.getFunctionType());
+ rewriter.cloneRegionBefore(funcOp.getBody(), newFunc.getBody(),
+ newFunc.getBody().end());
+ Block *oldEntryBlock = &newFunc.getBody().front();
+ auto blockArgTypes = oldEntryBlock->getArgumentTypes();
+ TypeConverter::SignatureConversion sC{oldEntryBlock->getNumArguments()};
+ auto numArgs = blockArgTypes.size();
+ for (size_t i = 0; i < numArgs; ++i) {
+ auto argType = dyn_cast<LocalRefType>(blockArgTypes[i]);
+ if (!argType)
+ return failure();
+ sC.addInputs(i, argType.getElementType());
+ }
+
+ rewriter.applySignatureConversion(oldEntryBlock, sC, getTypeConverter());
+ rewriter.replaceOp(funcOp, newFunc);
+ return success();
+ }
+};
+
+struct WasmGlobalImportOpConverter : OpConversionPattern<GlobalImportOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(GlobalImportOp gIOp, GlobalImportOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto memrefGOp = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ gIOp, gIOp.getSymNameAttr(), rewriter.getStringAttr("nested"),
+ TypeAttr::get(MemRefType::get({1}, gIOp.getType())), Attribute{},
+ /*constant*/ UnitAttr{},
+ /*alignment*/ IntegerAttr{});
+ memrefGOp.setConstant(!gIOp.getIsMutable());
+ return success();
+ }
+};
+
+template <typename CRTP, typename OriginOpType>
+struct GlobalOpConverter : OpConversionPattern<GlobalOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(GlobalOp globalOp, GlobalOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ ReturnOp rop;
+ globalOp->walk([&rop](ReturnOp op) { rop = op; });
+
+ if (rop->getNumOperands() != 1)
+ return rewriter.notifyMatchFailure(
+ globalOp, "globalOp initializer should return one value exactly");
+
+ auto initializerOp =
+ dyn_cast<OriginOpType>(rop->getOperand(0).getDefiningOp());
+
+ if (!initializerOp)
+ return rewriter.notifyMatchFailure(
+ globalOp, "invalid initializer op type for this pattern");
+
+ return static_cast<CRTP const *>(this)->handleInitializer(
+ globalOp, rewriter, initializerOp);
+ }
+};
+
+struct WasmGlobalWithConstInitConversion
+ : GlobalOpConverter<WasmGlobalWithConstInitConversion, ConstOp> {
+ using GlobalOpConverter::GlobalOpConverter;
+ LogicalResult handleInitializer(GlobalOp globalOp,
+ ConversionPatternRewriter &rewriter,
+ ConstOp constInit) const {
+ auto initializer =
+ DenseElementsAttr::get(RankedTensorType::get({1}, globalOp.getType()),
+ ArrayRef<Attribute>{constInit.getValueAttr()});
+ auto globalReplacement = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ globalOp, globalOp.getSymNameAttr(), rewriter.getStringAttr("private"),
+ TypeAttr::get(MemRefType::get({1}, globalOp.getType())), initializer,
+ /*constant*/ UnitAttr{},
+ /*alignment*/ IntegerAttr{});
+ globalReplacement.setConstant(!globalOp.getIsMutable());
+ return success();
+ }
+};
+
+struct WasmGlobalWithGetGlobalInitConversion
+ : GlobalOpConverter<WasmGlobalWithGetGlobalInitConversion, GlobalGetOp> {
+ using GlobalOpConverter::GlobalOpConverter;
+ LogicalResult handleInitializer(GlobalOp globalOp,
+ ConversionPatternRewriter &rewriter,
+ GlobalGetOp constInit) const {
+ auto globalReplacement = rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+ globalOp, globalOp.getSymNameAttr(), rewriter.getStringAttr("private"),
+ TypeAttr::get(MemRefType::get({1}, globalOp.getType())),
+ rewriter.getUnitAttr(),
+ /*constant*/ UnitAttr{},
+ /*alignment*/ IntegerAttr{});
+ globalReplacement.setConstant(!globalOp.getIsMutable());
+ auto loc = globalOp.getLoc();
+ auto initializerName = (globalOp.getSymName() + "::initializer").str();
+ auto globalInitializer = rewriter.create<func::FuncOp>(
+ loc, initializerName, FunctionType::get(getContext(), {}, {}));
+ globalInitializer->setAttr(rewriter.getStringAttr("initializer"),
+ rewriter.getUnitAttr());
+ auto *initializerBody = globalInitializer.addEntryBlock();
+ auto sip = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPointToStart(initializerBody);
+ auto srcGlobalPtr = rewriter.create<memref::GetGlobalOp>(
+ loc, MemRefType::get({1}, constInit.getType()), constInit.getGlobal());
+ auto destGlobalPtr = rewriter.create<memref::GetGlobalOp>(
+ loc, globalReplacement.getType(), globalReplacement.getSymName());
+ auto idx = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
+ auto loadSrc =
+ rewriter.create<memref::LoadOp>(loc, srcGlobalPtr, ValueRange{idx});
+ rewriter.create<memref::StoreOp>(
+ loc, loadSrc.getResult(), destGlobalPtr.getResult(), ValueRange{idx});
+ rewriter.create<func::ReturnOp>(loc);
+ rewriter.restoreInsertionPoint(sip);
+ return success();
+ }
+};
+
+inline TypedAttr getInitializerAttr(Type t) {
+ assert(t.isIntOrFloat() &&
+ "This helper is intended to use with int and float types");
+ if (t.isInteger())
+ return IntegerAttr::get(t, 0);
+ if (t.isFloat())
+ return FloatAttr::get(t, 0.);
+ return TypedAttr{};
+}
+
+struct WasmLocalConversion : OpConversionPattern<LocalOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(LocalOp localOp, LocalOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
+ localOp,
+ MemRefType::get({}, localOp.getResult().getType().getElementType()));
+ auto initializer = rewriter.create<arith::ConstantOp>(
+ localOp->getLoc(),
+ getInitializerAttr(localOp.getResult().getType().getElementType()));
+ rewriter.create<memref::StoreOp>(localOp->getLoc(), initializer.getResult(),
+ alloca.getResult());
+ return success();
+ }
+};
+
+struct WasmLocalGetConversion : OpConversionPattern<LocalGetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(LocalGetOp localGetOp, LocalGetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<memref::LoadOp>(
+ localGetOp, localGetOp.getResult().getType(), adaptor.getLocalVar(),
+ ValueRange{});
+ return success();
+ }
+};
+
+struct WasmLocalSetConversion : OpConversionPattern<LocalSetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(LocalSetOp localSetOp, LocalSetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<memref::StoreOp>(
+ localSetOp, adaptor.getValue(), adaptor.getLocalVar(), ValueRange{});
+ return success();
+ }
+};
+
+struct WasmLocalTeeConversion : OpConversionPattern<LocalTeeOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(LocalTeeOp localTeeOp, LocalTeeOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.create<memref::StoreOp>(localTeeOp->getLoc(), adaptor.getValue(),
+ adaptor.getLocalVar());
+ rewriter.replaceOp(localTeeOp, adaptor.getValue());
+ return success();
+ }
+};
+
+struct WasmReturnOpConversion : OpConversionPattern<ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ReturnOp returnOp, ReturnOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<func::ReturnOp>(returnOp,
+ adaptor.getOperands());
+ return success();
+ }
+};
+
+struct RaiseWasmMLIRPass : public impl::RaiseWasmMLIRBase<RaiseWasmMLIRPass> {
+ void runOnOperation() override {
+ ConversionTarget target{getContext()};
+ target.addIllegalDialect<WasmSSADialect>();
+ target.addLegalDialect<arith::ArithDialect, BuiltinDialect,
+ cf::ControlFlowDial...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
f900cc4 to
1992924
Compare
arithmetic ops -- Co-authored-by: Luc Forget <[email protected]> Co-authored-by: Jessica Paquette <[email protected]>
1992924 to
86a772d
Compare
There is either a naming or description problem here: "raising" and "lowering" are opposite terms to me, can you clarify which one is correct here? (Or are you raising some things and lowering others?) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just skimmed through it, that seems fine to me overall.
I'm not particularly attached to that name, it used to be |
…obalOp to retrieve its terminator operation
fc3fee2 to
4577705
Compare
d6e6e9f to
45e31c0
Compare
45e31c0 to
04cb1df
Compare
This is following #154674 and still related to https://discourse.llvm.org/t/rfc-mlir-dialect-for-webassembly/86758.
This PR introduces the RaiseWasmMLIRPass. This pass lowers WasmSSA MLIR to other dialects of the LLVM ecosystem (namely: arith, math, cf and memref).
This is the first PR of a series of 2 or 3 to introduce the lowering, as an introduction it brings support for function calls, local and global variables and handling of arithmetic operations. As explained in the RFC, most WasmSSA operations have been made to stay close to other dialects' semantics so that conversion is trivialized.