Skip to content

Commit cbacf1c

Browse files
authored
Optimize contiguous layout arrays handling (#150)
1 parent 85bb257 commit cbacf1c

File tree

3 files changed

+151
-11
lines changed

3 files changed

+151
-11
lines changed

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,26 @@ def py_func(a):
814814

815815
assert_equal(py_func(arr), jit_func(arr))
816816

817+
def test_contigious_layout_opt():
818+
def py_func(a):
819+
return a[0,1]
820+
821+
jit_func = njit(py_func)
822+
823+
a = np.array([[1,2],[3,4]])
824+
b = a.T
825+
826+
with print_pass_ir([],['MakeStridedLayoutPass']):
827+
assert_equal(py_func(a), jit_func(a))
828+
ir = get_print_buffer()
829+
assert ir.count('affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>') == 0, ir
830+
831+
with print_pass_ir([],['MakeStridedLayoutPass']):
832+
assert_equal(py_func(b), jit_func(b))
833+
ir = get_print_buffer()
834+
assert ir.count('affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>') == 1, ir
835+
836+
817837
@parametrize_function_variants("a", [
818838
# 'np.array(1)', TODO zero rank arrays
819839
# 'np.array(2.5)',

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <mlir/Dialect/Arithmetic/Transforms/Passes.h>
3434
#include <mlir/Dialect/GPU/ParallelLoopMapper.h>
3535
#include <mlir/Dialect/GPU/Passes.h>
36+
#include <mlir/Dialect/GPU/Utils.h>
3637
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
3738
#include <mlir/Dialect/Math/IR/Math.h>
3839
#include <mlir/Dialect/MemRef/IR/MemRef.h>
@@ -2593,6 +2594,13 @@ void MarkGpuArraysInputs::runOnOperation() {
25932594
auto func = getOperation();
25942595
auto funcType = func.getType();
25952596

2597+
mlir::OpBuilder builder(&getContext());
2598+
auto attrStr = builder.getStringAttr(kGpuArgAttr);
2599+
if (func->hasAttr(attrStr)) {
2600+
markAllAnalysesPreserved();
2601+
return;
2602+
}
2603+
25962604
bool needAttr = false;
25972605
llvm::SmallVector<bool> result;
25982606
result.reserve(funcType.getNumInputs());
@@ -2606,10 +2614,8 @@ void MarkGpuArraysInputs::runOnOperation() {
26062614
for (auto type : (func.getType().getInputs()))
26072615
visitTypeRecursive(type, visitor);
26082616

2609-
if (needAttr) {
2610-
mlir::OpBuilder builder(&getContext());
2611-
func->setAttr(kGpuArgAttr, builder.getBoolArrayAttr(result));
2612-
}
2617+
if (needAttr)
2618+
func->setAttr(attrStr, builder.getBoolArrayAttr(result));
26132619

26142620
markAllAnalysesPreserved();
26152621
}
@@ -2921,6 +2927,31 @@ struct LowerGpuBuiltinsPass
29212927
: public plier::RewriteWrapperPass<LowerGpuBuiltinsPass, void, void,
29222928
LowerPlierCalls, LowerBuiltinCalls> {};
29232929

2930+
class GpuLaunchSinkOpsPass
2931+
: public mlir::PassWrapper<GpuLaunchSinkOpsPass,
2932+
mlir::OperationPass<void>> {
2933+
public:
2934+
void runOnOperation() override {
2935+
using namespace mlir;
2936+
2937+
Operation *op = getOperation();
2938+
if (op->walk([](gpu::LaunchOp launch) {
2939+
auto isSinkingBeneficiary = [](mlir::Operation *op) -> bool {
2940+
return isa<arith::ConstantOp, ConstantOp, arith::SelectOp,
2941+
arith::CmpIOp>(op);
2942+
};
2943+
2944+
// Pull in instructions that can be sunk
2945+
if (failed(
2946+
sinkOperationsIntoLaunchOp(launch, isSinkingBeneficiary)))
2947+
return WalkResult::interrupt();
2948+
2949+
return WalkResult::advance();
2950+
}).wasInterrupted())
2951+
signalPassFailure();
2952+
}
2953+
};
2954+
29242955
static void commonOptPasses(mlir::OpPassManager &pm) {
29252956
pm.addPass(mlir::createCanonicalizerPass());
29262957
pm.addPass(mlir::createCSEPass());
@@ -2949,12 +2980,15 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
29492980
funcPM.addPass(mlir::createCanonicalizerPass());
29502981
funcPM.addPass(std::make_unique<UnstrideMemrefsPass>());
29512982
funcPM.addPass(mlir::createLowerAffinePass());
2983+
2984+
// TODO: mlir::gpu::GPUModuleOp pass
2985+
pm.addNestedPass<mlir::FuncOp>(mlir::arith::createArithmeticExpandOpsPass());
29522986
commonOptPasses(funcPM);
29532987
funcPM.addPass(std::make_unique<KernelMemrefOpsMovementPass>());
2954-
2988+
funcPM.addPass(std::make_unique<GpuLaunchSinkOpsPass>());
2989+
pm.addPass(mlir::createGpuKernelOutliningPass());
29552990
pm.addPass(mlir::createSymbolDCEPass());
29562991

2957-
pm.addPass(mlir::createGpuKernelOutliningPass());
29582992
pm.addNestedPass<mlir::FuncOp>(std::make_unique<GPULowerDefaultLocalSize>());
29592993
pm.nest<mlir::gpu::GPUModuleOp>().addNestedPass<mlir::gpu::GPUFuncOp>(
29602994
mlir::arith::createArithmeticExpandOpsPass());

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_linalg.cpp

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,9 @@ struct ReshapeChangeLayout
12001200
}
12011201
};
12021202

1203+
static constexpr llvm::StringLiteral
1204+
kContigiousArraysAttr("plier.contigious_arrays");
1205+
12031206
struct MakeStridedLayoutPass
12041207
: public mlir::PassWrapper<MakeStridedLayoutPass,
12051208
mlir::OperationPass<mlir::ModuleOp>> {
@@ -1210,23 +1213,48 @@ void MakeStridedLayoutPass::runOnOperation() {
12101213
auto context = &getContext();
12111214
auto mod = getOperation();
12121215

1216+
mlir::OpBuilder builder(mod);
1217+
auto loc = builder.getUnknownLoc();
1218+
auto attrStr = builder.getStringAttr(kContigiousArraysAttr);
1219+
1220+
llvm::SmallVector<bool> contigiousArrayArg;
1221+
1222+
auto isContigiousArrayArg = [&](unsigned i) {
1223+
if (contigiousArrayArg.empty())
1224+
return false;
1225+
1226+
assert(i < contigiousArrayArg.size());
1227+
return contigiousArrayArg[i];
1228+
};
1229+
12131230
llvm::SmallVector<mlir::Type> newArgTypes;
12141231
llvm::SmallVector<mlir::Type> newResTypes;
12151232
llvm::SmallVector<mlir::Value> newOperands;
12161233
for (auto func : mod.getOps<mlir::FuncOp>()) {
1217-
mlir::OpBuilder builder(func.body());
1218-
auto loc = builder.getUnknownLoc();
1234+
auto contAttr = func->getAttr(attrStr).dyn_cast_or_null<mlir::ArrayAttr>();
1235+
if (contAttr) {
1236+
auto contAttrRange = contAttr.getAsValueRange<mlir::BoolAttr>();
1237+
contigiousArrayArg.assign(contAttrRange.begin(), contAttrRange.end());
1238+
} else {
1239+
contigiousArrayArg.clear();
1240+
}
1241+
12191242
auto funcType = func.getType();
12201243
auto argTypes = funcType.getInputs();
12211244
auto resTypes = funcType.getResults();
12221245
newArgTypes.assign(argTypes.begin(), argTypes.end());
12231246
newResTypes.assign(resTypes.begin(), resTypes.end());
1224-
bool hasBody = !func.getBody().empty();
1247+
auto &body = func.getBody();
1248+
bool hasBody = !body.empty();
1249+
if (hasBody)
1250+
builder.setInsertionPointToStart(&body.front());
1251+
12251252
for (auto it : llvm::enumerate(argTypes)) {
12261253
auto i = static_cast<unsigned>(it.index());
12271254
auto type = it.value();
12281255
auto memrefType = type.dyn_cast<mlir::MemRefType>();
1229-
if (!memrefType || !memrefType.getLayout().isIdentity())
1256+
if (!memrefType || isContigiousArrayArg(i) ||
1257+
!memrefType.getLayout().isIdentity())
12301258
continue;
12311259

12321260
auto rank = static_cast<unsigned>(memrefType.getRank());
@@ -1244,7 +1272,7 @@ void MakeStridedLayoutPass::runOnOperation() {
12441272
newArgTypes[i] = newMemrefType;
12451273

12461274
if (hasBody) {
1247-
auto arg = func.getBody().front().getArgument(i);
1275+
auto arg = body.front().getArgument(i);
12481276
arg.setType(newMemrefType);
12491277
auto dst =
12501278
builder.create<plier::ChangeLayoutOp>(loc, memrefType, arg);
@@ -2116,6 +2144,63 @@ void PostPlierToLinalgPass::runOnOperation() {
21162144
(void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
21172145
}
21182146

2147+
template <typename F>
2148+
static void visitTypeRecursive(mlir::Type type, F &&visitor) {
2149+
if (auto tupleType = type.dyn_cast<mlir::TupleType>()) {
2150+
for (auto t : tupleType.getTypes())
2151+
visitTypeRecursive(t, std::forward<F>(visitor));
2152+
} else {
2153+
visitor(type);
2154+
}
2155+
}
2156+
2157+
static bool isContigiousArray(mlir::Type type) {
2158+
auto pyType = type.dyn_cast<plier::PyType>();
2159+
if (!pyType)
2160+
return false;
2161+
2162+
auto name = pyType.getName();
2163+
auto desc = parseArrayDesc(name);
2164+
if (!desc)
2165+
return false;
2166+
2167+
return desc->layout == ArrayLayout::C;
2168+
}
2169+
2170+
struct MarkContigiousArraysPass
2171+
: public mlir::PassWrapper<MarkContigiousArraysPass,
2172+
mlir::OperationPass<mlir::FuncOp>> {
2173+
void runOnOperation() override {
2174+
auto func = getOperation();
2175+
auto funcType = func.getType();
2176+
2177+
mlir::OpBuilder builder(&getContext());
2178+
auto attrStr = builder.getStringAttr(kContigiousArraysAttr);
2179+
if (func->hasAttr(attrStr)) {
2180+
markAllAnalysesPreserved();
2181+
return;
2182+
}
2183+
2184+
bool needAttr = false;
2185+
llvm::SmallVector<bool> result;
2186+
result.reserve(funcType.getNumInputs());
2187+
2188+
auto visitor = [&](mlir::Type type) {
2189+
auto res = isContigiousArray(type);
2190+
result.emplace_back(res);
2191+
needAttr = needAttr || res;
2192+
};
2193+
2194+
for (auto type : (func.getType().getInputs()))
2195+
visitTypeRecursive(type, visitor);
2196+
2197+
if (needAttr)
2198+
func->setAttr(attrStr, builder.getBoolArrayAttr(result));
2199+
2200+
markAllAnalysesPreserved();
2201+
}
2202+
};
2203+
21192204
template <typename Op>
21202205
struct ConvertAlloc : public mlir::OpConversionPattern<Op> {
21212206
using mlir::OpConversionPattern<Op>::OpConversionPattern;
@@ -2601,6 +2686,7 @@ struct FixDeallocPlacementPass
26012686
void, FixDeallocPlacement> {};
26022687

26032688
static void populatePlierToLinalgGenPipeline(mlir::OpPassManager &pm) {
2689+
pm.addNestedPass<mlir::FuncOp>(std::make_unique<MarkContigiousArraysPass>());
26042690
pm.addPass(std::make_unique<PlierToLinalgPass>());
26052691
pm.addPass(mlir::createCanonicalizerPass());
26062692
pm.addPass(std::make_unique<NumpyCallsLoweringPass>());

0 commit comments

Comments
 (0)