Skip to content

Commit 8438278

Browse files
authored
Ops cleanup (#194)
* Remove `DelOp`, `UndefOp` and `ExtractMemrefMetadataOp` from `gpu_runtime` dialect * Move `UndefOp` from `plier` to `plier_util` dialect * Cleanup headers in `gpu_runtimr_to_llvm.hpp`
1 parent 6c854b9 commit 8438278

File tree

9 files changed

+38
-123
lines changed

9 files changed

+38
-123
lines changed

mlir/include/mlir-extensions/Conversion/gpu_runtime_to_llvm.hpp

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,11 @@
1414

1515
#pragma once
1616

17-
#include "mlir-extensions/dialect/gpu_runtime/IR/gpu_runtime_ops.hpp"
18-
#include "mlir-extensions/transforms/func_utils.hpp"
17+
#include <memory>
1918

20-
#include <mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h>
21-
#include <mlir/Conversion/GPUCommon/GPUCommonPass.h>
22-
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
23-
#include <mlir/Conversion/LLVMCommon/Pattern.h>
24-
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
25-
#include <mlir/Dialect/GPU/Passes.h>
26-
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
27-
#include <mlir/Pass/PassManager.h>
28-
#include <mlir/Transforms/DialectConversion.h>
29-
#include <mlir/Transforms/Passes.h>
19+
namespace mlir {
20+
class Pass;
21+
}
3022

3123
namespace gpu_runtime {
3224

mlir/include/mlir-extensions/dialect/gpu_runtime/IR/GpuRuntimeOps.td

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,6 @@ def GpuRuntime_OpaqueType
4040
"opaque_type">,
4141
BuildableType<"$_builder.getType<::gpu_runtime::OpaqueType>()"> {}
4242

43-
def DelOp : GpuRuntime_Op<"del", []> { let arguments = (ins AnyType : $value); }
44-
45-
def UndefOp : GpuRuntime_Op<"undef", [NoSideEffect]> {
46-
let results = (outs AnyType);
47-
}
48-
49-
def ExtractMemrefMetadataOp
50-
: GpuRuntime_Op<"extract_memref_metadata", [NoSideEffect]> {
51-
let arguments = (ins AnyMemRef : $source, IndexAttr : $dimIndex);
52-
53-
let results = (outs Index : $result);
54-
let hasFolder = 1;
55-
56-
let builders = [
57-
OpBuilder<(ins "::mlir::Value" : $src,
58-
"int64_t" : $dim)>,
59-
OpBuilder<(ins "::mlir::Value" : $src)>
60-
];
61-
}
62-
6343
def CreateGpuStreamOp : GpuRuntime_Op<"create_gpu_stream", [NoSideEffect]> {
6444
let results = (outs GpuRuntime_OpaqueType : $result);
6545

mlir/include/mlir-extensions/dialect/plier/PlierOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,6 @@ def ExhaustIterOp : Plier_Op<"exhaust_iter", [NoSideEffect]> {
205205
: $count)>];
206206
}
207207

208-
def UndefOp : Plier_Op<"undef", [NoSideEffect]> {
209-
let results = (outs AnyType);
210-
}
211-
212208
def BuildSliceOp : Plier_Op<"build_slice", [NoSideEffect]> {
213209
let arguments = (ins AnyType : $begin, AnyType : $end, AnyType : $step);
214210

mlir/include/mlir-extensions/dialect/plier_util/PlierUtilOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,12 @@ def ParallelOp : PlierUtil_Op<"parallel", [
7373
"nullptr">)>];
7474

7575
let extraClassDeclaration = [{
76-
unsigned getNumLoops() { return steps().size();
76+
unsigned getNumLoops() { return steps().size(); }
77+
}];
7778
}
78-
}];
79+
80+
def UndefOp : PlierUtil_Op<"undef", [NoSideEffect]> {
81+
let results = (outs AnyType);
7982
}
8083

8184
def YieldOp : PlierUtil_Op<"yield", [

mlir/lib/Conversion/gpu_runtime_to_llvm.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,28 @@
1414

1515
#include "mlir-extensions/Conversion/gpu_runtime_to_llvm.hpp"
1616

17+
#include "mlir-extensions/dialect/gpu_runtime/IR/gpu_runtime_ops.hpp"
18+
#include "mlir-extensions/dialect/plier_util/dialect.hpp"
19+
#include "mlir-extensions/transforms/func_utils.hpp"
20+
21+
#include <mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h>
22+
#include <mlir/Conversion/GPUCommon/GPUCommonPass.h>
23+
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
24+
#include <mlir/Conversion/LLVMCommon/Pattern.h>
25+
#include <mlir/Conversion/LLVMCommon/TypeConverter.h>
26+
#include <mlir/Dialect/GPU/Passes.h>
27+
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
28+
#include <mlir/Pass/PassManager.h>
29+
#include <mlir/Transforms/DialectConversion.h>
30+
#include <mlir/Transforms/Passes.h>
31+
1732
static const char *kGpuAllocShared = "gpu.alloc_shared";
1833

19-
struct LowerUndef : public mlir::ConvertOpToLLVMPattern<gpu_runtime::UndefOp> {
34+
struct LowerUndef : public mlir::ConvertOpToLLVMPattern<plier::UndefOp> {
2035
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2136

2237
mlir::LogicalResult
23-
matchAndRewrite(gpu_runtime::UndefOp op,
24-
gpu_runtime::UndefOp::Adaptor /*adaptor*/,
38+
matchAndRewrite(plier::UndefOp op, plier::UndefOp::Adaptor /*adaptor*/,
2539
mlir::ConversionPatternRewriter &rewriter) const override {
2640
auto converter = getTypeConverter();
2741
auto type = converter->convertType(op.getType());

mlir/lib/Conversion/gpu_to_gpu_runtime.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir-extensions/Conversion/gpu_to_gpu_runtime.hpp"
1616

1717
#include "mlir-extensions/dialect/gpu_runtime/IR/gpu_runtime_ops.hpp"
18+
#include "mlir-extensions/dialect/plier_util/dialect.hpp"
1819

1920
#include <mlir/Analysis/BufferViewFlowAnalysis.h>
2021
#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
@@ -411,14 +412,13 @@ static mlir::Value getFlatIndex(mlir::OpBuilder &builder, mlir::Location loc,
411412
auto numSymbols = affineMap.getNumSymbols();
412413
if (numSymbols > 0) {
413414
applyOperands.emplace_back(
414-
builder.createOrFold<gpu_runtime::ExtractMemrefMetadataOp>(loc,
415-
memref));
415+
builder.createOrFold<plier::ExtractMemrefMetadataOp>(loc, memref));
416416
--numSymbols;
417417
assert(numSymbols <= rank);
418418
for (auto i : llvm::seq(0u, numSymbols)) {
419419
applyOperands.emplace_back(
420-
builder.createOrFold<gpu_runtime::ExtractMemrefMetadataOp>(
421-
loc, memref, i));
420+
builder.createOrFold<plier::ExtractMemrefMetadataOp>(loc, memref,
421+
i));
422422
}
423423
}
424424
}
@@ -453,7 +453,7 @@ static mlir::Value getFlatMemref(mlir::OpBuilder &builder, mlir::Location loc,
453453
setInsertionPointToStart(builder, memref);
454454
mlir::OpFoldResult offset = builder.getIndexAttr(0);
455455
mlir::OpFoldResult size =
456-
builder.createOrFold<gpu_runtime::UndefOp>(loc, builder.getIndexType());
456+
builder.createOrFold<plier::UndefOp>(loc, builder.getIndexType());
457457
mlir::OpFoldResult stride = builder.getIndexAttr(1);
458458
return builder.createOrFold<mlir::memref::ReinterpretCastOp>(
459459
loc, resultType, memref, offset, size, stride);
@@ -539,7 +539,7 @@ struct FlattenSubview : public mlir::OpRewritePattern<mlir::memref::SubViewOp> {
539539
auto loc = op.getLoc();
540540
mlir::OpFoldResult flatIndex = getFlatIndex(rewriter, loc, memref, offsets);
541541
mlir::OpFoldResult flatSize =
542-
rewriter.create<gpu_runtime::UndefOp>(loc, rewriter.getIndexType())
542+
rewriter.create<plier::UndefOp>(loc, rewriter.getIndexType())
543543
.getResult();
544544
mlir::OpFoldResult flatStride = rewriter.getIndexAttr(1);
545545
auto flatMemref = getFlatMemref(rewriter, loc, memref);
@@ -572,7 +572,7 @@ struct FlattenSubview : public mlir::OpRewritePattern<mlir::memref::SubViewOp> {
572572
auto origStride = [&]() {
573573
mlir::OpBuilder::InsertionGuard g(rewriter);
574574
setInsertionPointToStart(rewriter, memref);
575-
return rewriter.createOrFold<gpu_runtime::ExtractMemrefMetadataOp>(
575+
return rewriter.createOrFold<plier::ExtractMemrefMetadataOp>(
576576
loc, memref, i);
577577
}();
578578
auto newStride = rewriter.createOrFold<mlir::arith::MulIOp>(

mlir/lib/dialect/gpu_runtime/IR/gpu_runtime_ops.cpp

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -90,81 +90,6 @@ OpaqueType OpaqueType::get(mlir::MLIRContext *context) {
9090
return Base::get(context);
9191
}
9292

93-
void ExtractMemrefMetadataOp::build(::mlir::OpBuilder &odsBuilder,
94-
::mlir::OperationState &odsState,
95-
::mlir::Value src, int64_t dim) {
96-
assert(dim >= 0 && dim < src.getType().cast<mlir::MemRefType>().getRank());
97-
ExtractMemrefMetadataOp::build(odsBuilder, odsState,
98-
odsBuilder.getIndexType(), src,
99-
odsBuilder.getIndexAttr(dim));
100-
}
101-
102-
void ExtractMemrefMetadataOp::build(::mlir::OpBuilder &odsBuilder,
103-
::mlir::OperationState &odsState,
104-
::mlir::Value src) {
105-
ExtractMemrefMetadataOp::build(odsBuilder, odsState,
106-
odsBuilder.getIndexType(), src,
107-
odsBuilder.getIndexAttr(-1));
108-
}
109-
110-
mlir::OpFoldResult
111-
ExtractMemrefMetadataOp::fold(llvm::ArrayRef<mlir::Attribute> /*operands*/) {
112-
auto idx = dimIndex().getSExtValue();
113-
assert(idx >= -1);
114-
auto src = source();
115-
116-
int64_t offset;
117-
llvm::SmallVector<int64_t> strides;
118-
if (mlir::succeeded(mlir::getStridesAndOffset(
119-
src.getType().cast<mlir::MemRefType>(), strides, offset))) {
120-
mlir::Builder builder(getContext());
121-
if (idx == -1 && !mlir::ShapedType::isDynamicStrideOrOffset(offset)) {
122-
return builder.getIndexAttr(offset);
123-
} else if (idx >= 0 && idx < static_cast<int64_t>(strides.size()) &&
124-
!mlir::ShapedType::isDynamicStrideOrOffset(
125-
strides[static_cast<unsigned>(idx)])) {
126-
return builder.getIndexAttr(strides[static_cast<unsigned>(idx)]);
127-
}
128-
}
129-
130-
if (auto reintr = src.getDefiningOp<mlir::memref::ReinterpretCastOp>()) {
131-
if (idx == -1) {
132-
auto offsets = reintr.getMixedOffsets();
133-
if (offsets.size() == 1)
134-
return offsets.front();
135-
136-
return nullptr;
137-
}
138-
139-
auto strides = reintr.getMixedStrides();
140-
if (static_cast<unsigned>(idx) < strides.size())
141-
return strides[static_cast<unsigned>(idx)];
142-
143-
return nullptr;
144-
}
145-
146-
if (auto cast = src.getDefiningOp<mlir::memref::CastOp>()) {
147-
auto newSrc = cast.source();
148-
sourceMutable().assign(newSrc);
149-
return getResult();
150-
}
151-
152-
if (auto cast = src.getDefiningOp<mlir::memref::CastOp>()) {
153-
auto castSrc = cast.source();
154-
auto castSrcType = castSrc.getType().cast<mlir::ShapedType>();
155-
auto srcType = src.getType().cast<mlir::ShapedType>();
156-
if (castSrcType.hasRank() && srcType.hasRank() &&
157-
castSrcType.getRank() == srcType.getRank()) {
158-
sourceMutable().assign(castSrc);
159-
return getResult();
160-
}
161-
162-
return nullptr;
163-
}
164-
165-
return nullptr;
166-
}
167-
16893
namespace {
16994
template <typename Op, typename DelOp>
17095
struct RemoveUnusedOp : public mlir::OpRewritePattern<Op> {

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/lowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <llvm/Support/Debug.h>
3636

3737
#include "mlir-extensions/dialect/plier/dialect.hpp"
38+
#include "mlir-extensions/dialect/plier_util/dialect.hpp"
3839

3940
#include "mlir-extensions/compiler/compiler.hpp"
4041
#include "mlir-extensions/compiler/pipeline_registry.hpp"
@@ -153,6 +154,7 @@ struct PlierLowerer final {
153154
PlierLowerer(mlir::MLIRContext &context) : ctx(context), builder(&ctx) {
154155
ctx.loadDialect<mlir::func::FuncDialect>();
155156
ctx.loadDialect<plier::PlierDialect>();
157+
ctx.loadDialect<plier::PlierUtilDialect>();
156158
}
157159

158160
mlir::FuncOp lower(const py::object &compilationContext, mlir::ModuleOp mod,

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/plier_to_scf.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <mlir/Transforms/Passes.h>
2525

2626
#include "mlir-extensions/dialect/plier/dialect.hpp"
27+
#include "mlir-extensions/dialect/plier_util/dialect.hpp"
2728
#include "mlir-extensions/transforms/arg_lowering.hpp"
2829
#include "mlir-extensions/transforms/common_opts.hpp"
2930

@@ -599,6 +600,8 @@ struct PlierToScfPass
599600
virtual void
600601
getDependentDialects(mlir::DialectRegistry &registry) const override {
601602
registry.insert<mlir::scf::SCFDialect>();
603+
registry.insert<plier::PlierDialect>();
604+
registry.insert<plier::PlierUtilDialect>();
602605
}
603606

604607
void runOnOperation() override;

0 commit comments

Comments
 (0)