From 23983cd4bec230c38da4fd7d7c4d25ce2bb2839e Mon Sep 17 00:00:00 2001 From: matth2k Date: Mon, 22 Jan 2024 20:17:26 -0500 Subject: [PATCH] new FoldBitWidth pass to fix cornell-zhang/allo#121 --- include/hcl/Transforms/Passes.h | 1 + include/hcl/Transforms/Passes.td | 5 + lib/Transforms/CMakeLists.txt | 1 + lib/Transforms/FoldBitWidth.cpp | 321 +++++++++++++++++++++++++++++++ tools/hcl-opt/hcl-opt.cpp | 10 + 5 files changed, 338 insertions(+) create mode 100644 lib/Transforms/FoldBitWidth.cpp diff --git a/include/hcl/Transforms/Passes.h b/include/hcl/Transforms/Passes.h index be8f9426..e1d59739 100644 --- a/include/hcl/Transforms/Passes.h +++ b/include/hcl/Transforms/Passes.h @@ -17,6 +17,7 @@ std::unique_ptr> createAnyWidthIntegerPass(); std::unique_ptr> createMoveReturnToInputPass(); std::unique_ptr> createLegalizeCastPass(); std::unique_ptr> createRemoveStrideMapPass(); +std::unique_ptr> createFoldBitWidthPass(); std::unique_ptr> createMemRefDCEPass(); std::unique_ptr> createDataPlacementPass(); std::unique_ptr> createTransformInterpreterPass(); diff --git a/include/hcl/Transforms/Passes.td b/include/hcl/Transforms/Passes.td index 03788e6c..15d3f3bc 100644 --- a/include/hcl/Transforms/Passes.td +++ b/include/hcl/Transforms/Passes.td @@ -38,6 +38,11 @@ def RemoveStrideMap : Pass<"remove-stride-map", "ModuleOp"> { let constructor = "mlir::hcl::createRemoveStrideMapPass()"; } +def FoldBitWidth : Pass<"fold-bit-width", "ModuleOp"> { + let summary = "Remove ext and trunc operations surrounding wrap-around ops"; + let constructor = "mlir::hcl::createFoldBitWidthPass()"; +} + def MemRefDCE : Pass<"memref-dce", "ModuleOp"> { let summary = "Remove MemRefs that are never loaded from"; let constructor = "mlir::hcl::createMemRefDCEPass()"; diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index f2b87b1b..02114f1b 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_library(MLIRHCLPasses Passes.cpp LegalizeCast.cpp RemoveStrideMap.cpp + FoldBitWidth.cpp MemRefDCE.cpp DataPlacement.cpp TransformInterpreter.cpp diff --git a/lib/Transforms/FoldBitWidth.cpp b/lib/Transforms/FoldBitWidth.cpp new file mode 100644 index 00000000..2cfa390f --- /dev/null +++ b/lib/Transforms/FoldBitWidth.cpp @@ -0,0 +1,321 @@ +/* + * Copyright HeteroCL authors. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "PassDetail.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/RegionUtils.h" +#include +#include +#include +#include + +#include "hcl/Transforms/Passes.h" + +using namespace mlir; +using namespace hcl; + +namespace mlir { +namespace hcl { +template struct FoldWidth : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static std::optional operandIfExtended(Value operand) { + auto *definingOp = operand.getDefiningOp(); + if (!definingOp) + return std::nullopt; + + if (!isa(operand.getType())) + return std::nullopt; + + if (auto extOp = dyn_cast(*definingOp)) + return cast(extOp->getOperand(0).getType()); + if (auto extOp = dyn_cast(*definingOp)) + return cast(extOp->getOperand(0).getType()); + + return std::nullopt; + } + + static std::optional + valIfTruncated(TypedValue val) { + if (!val.hasOneUse()) + return std::nullopt; + auto *op = *val.getUsers().begin(); + if (auto trunc = dyn_cast(*op)) + if (auto truncType = dyn_cast(trunc.getType())) + return truncType; + + return std::nullopt; + } + + static bool opIsLegal(OpTy op) { + if (op->getNumResults() != 1) + return true; + if (op->getNumOperands() <= 0) + return true; + if (!isa(op->getResultTypes().front())) + return true; + + auto outType = + valIfTruncated(cast>(op->getResult(0))); + if (!outType.has_value()) + return true; + + auto operandType = operandIfExtended(op->getOperand(0)); + if (!operandType.has_value() || operandType != outType) + return true; + + // Extension and trunc should be opt away + SmallVector operands; + for (auto operand : op->getOperands()) { + auto oW = operandIfExtended(operand); + if (oW != operandType) + return true; + } + return false; + } + + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (opIsLegal(op)) + return failure(); + + auto outType = + valIfTruncated(cast>(op->getResult(0))); + + // Extension and trunc should be opt away + SmallVector operands; + for (auto operand : op->getOperands()) + operands.push_back(operand.getDefiningOp()->getOperand(0)); + + SmallVector resultTypes = {*outType}; + auto newOp = rewriter.create(op.getLoc(), resultTypes, operands); + auto trunc = *op->getUsers().begin(); + trunc->getResult(0).replaceAllUsesWith(newOp->getResult(0)); + rewriter.eraseOp(trunc); + rewriter.eraseOp(op); + + return success(); + } +}; + +template struct FoldLinalgWidth : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + static unsigned getIndex(mlir::Block::OpListType &opList, Operation *item) { + + for (auto op : enumerate(opList)) + if (&op.value() == item) + return op.index(); + assert(false && "Op not in Op list"); + } + + static SmallVector getUsersSorted(Value memref) { + SmallVector users(memref.getUsers().begin(), + memref.getUsers().end()); + + std::sort(users.begin(), users.end(), + [&memref](Operation *a, Operation *b) { + return getIndex(memref.getParentBlock()->getOperations(), a) < + getIndex(memref.getParentBlock()->getOperations(), b); + }); + + return users; + } + + static std::optional, linalg::GenericOp>> + operandIfExtended(TypedValue memref) { + if (memref.getUsers().empty()) + return std::nullopt; + + auto users = getUsersSorted(memref); + // If a buffer is used for the sake of type-conversion it should only have 2 + // uses. + if (users.size() != 2) + return std::nullopt; + + // If this is an extended operand, the first use should be a GenericOp that + // extends + if (!isa(users.front())) + return std::nullopt; + + auto genericOp = cast(users.front()); + + // Check that the Generic Op is used to extend with memref as an output + if (genericOp.getOutputs().front() != memref || + genericOp.getBody()->getOperations().size() != 2 || + genericOp.getInputs().size() != 1) + return std::nullopt; + + auto &operation = genericOp.getBody()->front(); + if (!isa(operation) && !isa(operation)) + return std::nullopt; + + // Return the memory buffer that is being extended and the GenericOp too + return std::pair( + cast>(genericOp.getInputs().front()), genericOp); + } + + static std::optional, linalg::GenericOp>> + valIfTruncated(TypedValue memref) { + if (memref.getUsers().empty()) + return std::nullopt; + + auto users = getUsersSorted(memref); + // If a buffer is used for the sake of type-conversion it should only have 2 + // uses. + if (users.size() != 2) + return std::nullopt; + + // If this is an truncated operand, the last use should be a GenericOp that + // truncates + if (!isa(users.back())) + return std::nullopt; + + auto genericOp = cast(users.back()); + + // Check that the Generic Op is used to truncate the memref input + if (genericOp.getInputs().front() != memref || + genericOp.getBody()->getOperations().size() != 2 || + genericOp.getOutputs().size() != 1) + return std::nullopt; + + auto &operation = genericOp.getBody()->front(); + if (!isa(operation)) + return std::nullopt; + + // Return the memory buffer that is being truncated and the GenericOp too + return std::pair( + cast>(genericOp.getOutputs().front()), + genericOp); + } + + // Test if we should apply this pattern or not + static bool opIsLegal(OpTy op) { + + // Should be a binary operation + if (op.getInputs().size() != 2) + return true; + if (op.getOutputs().size() != 1) + return true; + + auto outType = + valIfTruncated(cast>(op.getOutputs().front())); + if (!outType.has_value()) + return true; + + auto inputs = op.getInputs(); + auto firstOperand = + operandIfExtended(cast>(inputs[0])); + if (!firstOperand.has_value() || + firstOperand->first.getType() != outType->first.getType()) + return true; + + auto secondOperand = + operandIfExtended(cast>(inputs[1])); + if (!secondOperand.has_value() || + firstOperand->first.getType() != secondOperand->first.getType()) + return true; + + // At this point, we know all memref types are equivalent so the pattern + // should be applied + return false; + } + + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (opIsLegal(op)) + return failure(); + + auto outType = + valIfTruncated(cast>(op.getOutputs().front())); + auto inputs = op.getInputs(); + auto firstOperand = + operandIfExtended(cast>(inputs[0])); + auto secondOperand = + operandIfExtended(cast>(inputs[1])); + + // Extension and trunc should be opt away + SmallVector operands({firstOperand->first, secondOperand->first}); + + SmallVector results({outType->first}); + + // Create the new linalg operation, and move the output memory buffer up in + // the instructions so that it dominates + auto newop = rewriter.create(op->getLoc(), operands, results); + newop.getOutputs().front().getDefiningOp()->moveBefore(newop); + + // It is safe to delete these operations, because we force that each + // memory buffer only has 2 uses + rewriter.eraseOp(outType->second); + rewriter.eraseOp(firstOperand->second); + rewriter.eraseOp(secondOperand->second); + rewriter.eraseOp(op); + assert(opIsLegal(newop)); + + return success(); + } +}; +} // namespace hcl +} // namespace mlir + +namespace { +struct HCLFoldBitWidthTransformation + : public FoldBitWidthBase { + void runOnOperation() override { + auto *context = &getContext(); + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + // Patterns for scalar wraparound operations + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + + // Targets for scalar wraparound operations + target.addDynamicallyLegalOp( + FoldWidth::opIsLegal); + target.addDynamicallyLegalOp( + FoldWidth::opIsLegal); + target.addDynamicallyLegalOp( + FoldWidth::opIsLegal); + + // Patterns for linalg wraparound operations + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + + // Targets for linalg wraparound operations + target.addDynamicallyLegalOp( + FoldLinalgWidth::opIsLegal); + target.addDynamicallyLegalOp( + FoldLinalgWidth::opIsLegal); + target.addDynamicallyLegalOp( + FoldLinalgWidth::opIsLegal); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + OpBuilder builder(getOperation()); + IRRewriter rewriter(builder); + (void)runRegionDCE(rewriter, getOperation()->getRegions()); + } +}; +} // namespace + +namespace mlir { +namespace hcl { +std::unique_ptr> createFoldBitWidthPass() { + return std::make_unique(); +} +} // namespace hcl +} // namespace mlir diff --git a/tools/hcl-opt/hcl-opt.cpp b/tools/hcl-opt/hcl-opt.cpp index 230ff8dc..fb52c577 100644 --- a/tools/hcl-opt/hcl-opt.cpp +++ b/tools/hcl-opt/hcl-opt.cpp @@ -97,6 +97,12 @@ static llvm::cl::opt removeStrideMap("remove-stride-map", llvm::cl::desc("Remove stride map"), llvm::cl::init(false)); +static llvm::cl::opt foldBitWidth( + "fold-bit-width", + llvm::cl::desc( + "Remove ext and trunc operations surrounding wrap-around ops"), + llvm::cl::init(false)); + static llvm::cl::opt lowerPrintOps("lower-print-ops", llvm::cl::desc("Lower print ops"), llvm::cl::init(false)); @@ -319,6 +325,10 @@ int main(int argc, char **argv) { pm.addPass(mlir::hcl::createRemoveStrideMapPass()); } + if (foldBitWidth) { + pm.addPass(mlir::hcl::createFoldBitWidthPass()); + } + if (bufferization) { pm.addPass(mlir::bufferization::createOneShotBufferizePass()); }