Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,11 @@ def SPIRVWebGPUPreparePass : Pass<"spirv-webgpu-prepare", "spirv::ModuleOp"> {
"and replacing with supported ones";
}

def SPIRVReplicatedConstantCompositePass
: Pass<"spirv-convert-to-replicated-const-composite", "spirv::ModuleOp"> {
let summary = "Convert splat composite constants and spec constants to"
"corresponding replicated constant composite ops defined by"
"SPV_EXT_replicated_composites";
}

#endif // MLIR_DIALECT_SPIRV_TRANSFORMS_PASSES
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SPIRV/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
set(LLVM_OPTIONAL_SOURCES
CanonicalizeGLPass.cpp
ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
Expand Down Expand Up @@ -30,6 +31,7 @@ add_mlir_dialect_library(MLIRSPIRVConversion

add_mlir_dialect_library(MLIRSPIRVTransforms
CanonicalizeGLPass.cpp
ConversionToReplicatedConstantCompositePass.cpp
DecorateCompositeTypeLayoutPass.cpp
LowerABIAttributesPass.cpp
RewriteInsertsPass.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//===- ConversionToReplicatedConstantCompositePass.cpp --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert a splat composite spirv.Constant and
// spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
// spirv.EXT.SpecConstantCompositeReplicate respectively.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

using namespace mlir;

namespace {

Attribute getSplatAttribute(Attribute valueAttr, uint32_t &splatCount) {
Attribute attr;
if (auto denseAttr = dyn_cast<DenseElementsAttr>(valueAttr)) {
if (denseAttr.isSplat()) {
attr = denseAttr.getSplatValue<Attribute>();
splatCount = denseAttr.size();
}
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
if (std::adjacent_find(arrayAttr.begin(), arrayAttr.end(),
std::not_equal_to<>()) == arrayAttr.end()) {
attr = arrayAttr[0];
splatCount = arrayAttr.size();
}
}

if (attr) {
if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
if (isa<spirv::CompositeType>(typedAttr.getType()))
if (Attribute newAttr = getSplatAttribute(attr, splatCount))
attr = newAttr;
} else if (isa<ArrayAttr>(attr)) {
if (Attribute newAttr = getSplatAttribute(attr, splatCount))
attr = newAttr;
}
}

return attr;
}

} // namespace

namespace {
class ConversionToReplicatedConstantCompositePass
: public spirv::impl::SPIRVReplicatedConstantCompositePassBase<
ConversionToReplicatedConstantCompositePass> {
public:
void runOnOperation() override;
};

class ConstantOpConversion : public OpRewritePattern<spirv::ConstantOp> {
using OpRewritePattern<spirv::ConstantOp>::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::ConstantOp op,
PatternRewriter &rewriter) const override {
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
if (!compositeType)
return rewriter.notifyMatchFailure(op, "not a composite constant");

uint32_t splatCount = 0;
Attribute splatAttr = getSplatAttribute(op.getValue(), splatCount);
if (!splatAttr)
return rewriter.notifyMatchFailure(op, "composite is not splat");

if (splatCount == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one constituent");

rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
op, op.getType(), splatAttr);

return success();
}
};

class SpecConstantCompositeOpConversion
: public OpRewritePattern<spirv::SpecConstantCompositeOp> {
using OpRewritePattern<spirv::SpecConstantCompositeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
PatternRewriter &rewriter) const override {
auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
if (!compositeType)
return rewriter.notifyMatchFailure(op, "not a composite constant");

ArrayAttr constituents = op.getConstituents();
if (constituents.size() == 1)
return rewriter.notifyMatchFailure(op,
"composite has only one consituent");

if (!(std::adjacent_find(constituents.begin(), constituents.end(),
std::not_equal_to<>()) == constituents.end()))
return rewriter.notifyMatchFailure(op, "composite is not splat");

auto splatConstituent =
dyn_cast<FlatSymbolRefAttr>(op.getConstituents()[0]);
if (!splatConstituent)
return rewriter.notifyMatchFailure(
op, "expected flat symbol reference for splat constituent");

rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);

return success();
}
};

void ConversionToReplicatedConstantCompositePass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<ConstantOpConversion>(context);
patterns.add<SpecConstantCompositeOpConversion>(context);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
}

} // namespace
Loading