1818#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1919#include " mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
2020#include " mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21+ #include " mlir/IR/Attributes.h"
2122#include " mlir/IR/BuiltinAttributes.h"
2223#include " mlir/IR/BuiltinTypes.h"
2324#include " mlir/Interfaces/FunctionInterfaces.h"
2425#include " mlir/Transforms/DialectConversion.h"
26+ #include " llvm/ADT/SmallVectorExtras.h"
2527#include " llvm/ADT/StringExtras.h"
2628#include " llvm/Support/Debug.h"
2729
@@ -54,7 +56,8 @@ using namespace mlir;
5456 MAP_FN(spirv::StorageClass::PushConstant, 7 ) \
5557 MAP_FN(spirv::StorageClass::UniformConstant, 8 ) \
5658 MAP_FN(spirv::StorageClass::Input, 9 ) \
57- MAP_FN(spirv::StorageClass::Output, 10 )
59+ MAP_FN(spirv::StorageClass::Output, 10 ) \
60+ MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 11 )
5861
5962std::optional<spirv::StorageClass>
6063spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) {
@@ -185,13 +188,10 @@ spirv::MemorySpaceToStorageClassConverter::MemorySpaceToStorageClassConverter(
185188 });
186189
187190 addConversion ([this ](FunctionType type) {
188- SmallVector<Type> inputs, results;
189- inputs.reserve (type.getNumInputs ());
190- results.reserve (type.getNumResults ());
191- for (Type input : type.getInputs ())
192- inputs.push_back (convertType (input));
193- for (Type result : type.getResults ())
194- results.push_back (convertType (result));
191+ auto inputs = llvm::map_to_vector (
192+ type.getInputs (), [this ](Type ty) { return convertType (ty); });
193+ auto results = llvm::map_to_vector (
194+ type.getResults (), [this ](Type ty) { return convertType (ty); });
195195 return FunctionType::get (type.getContext (), inputs, results);
196196 });
197197}
@@ -250,49 +250,54 @@ spirv::getMemorySpaceToStorageClassTarget(MLIRContext &context) {
250250namespace {
251251// / Converts any op that has operands/results/attributes with numeric MemRef
252252// / memory spaces.
253- struct MapMemRefStoragePattern final : public ConversionPattern {
253+ struct MapMemRefStoragePattern final : ConversionPattern {
254254 MapMemRefStoragePattern (MLIRContext *context, TypeConverter &converter)
255255 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1 , context) {}
256256
257257 LogicalResult
258258 matchAndRewrite (Operation *op, ArrayRef<Value> operands,
259- ConversionPatternRewriter &rewriter) const override ;
260- };
261- } // namespace
262-
263- LogicalResult MapMemRefStoragePattern::matchAndRewrite (
264- Operation *op, ArrayRef<Value> operands,
265- ConversionPatternRewriter &rewriter) const {
266- llvm::SmallVector<NamedAttribute, 4 > newAttrs;
267- newAttrs.reserve (op->getAttrs ().size ());
268- for (auto attr : op->getAttrs ()) {
269- if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue ())) {
270- auto newAttr = getTypeConverter ()->convertType (typeAttr.getValue ());
271- newAttrs.emplace_back (attr.getName (), TypeAttr::get (newAttr));
272- } else {
273- newAttrs.push_back (attr);
259+ ConversionPatternRewriter &rewriter) const override {
260+ llvm::SmallVector<NamedAttribute> newAttrs;
261+ newAttrs.reserve (op->getAttrs ().size ());
262+ for (NamedAttribute attr : op->getAttrs ()) {
263+ if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue ())) {
264+ Type newAttr = getTypeConverter ()->convertType (typeAttr.getValue ());
265+ if (!newAttr) {
266+ return rewriter.notifyMatchFailure (
267+ op, " type attribute conversion failed" );
268+ }
269+ newAttrs.emplace_back (attr.getName (), TypeAttr::get (newAttr));
270+ } else {
271+ newAttrs.push_back (attr);
272+ }
274273 }
275- }
276274
277- llvm::SmallVector<Type, 4 > newResults;
278- (void )getTypeConverter ()->convertTypes (op->getResultTypes (), newResults);
279-
280- OperationState state (op->getLoc (), op->getName ().getStringRef (), operands,
281- newResults, newAttrs, op->getSuccessors ());
275+ llvm::SmallVector<Type, 4 > newResults;
276+ if (failed (
277+ getTypeConverter ()->convertTypes (op->getResultTypes (), newResults)))
278+ return rewriter.notifyMatchFailure (op, " result type conversion failed" );
279+
280+ OperationState state (op->getLoc (), op->getName ().getStringRef (), operands,
281+ newResults, newAttrs, op->getSuccessors ());
282+
283+ for (Region ®ion : op->getRegions ()) {
284+ Region *newRegion = state.addRegion ();
285+ rewriter.inlineRegionBefore (region, *newRegion, newRegion->begin ());
286+ TypeConverter::SignatureConversion result (newRegion->getNumArguments ());
287+ if (failed (getTypeConverter ()->convertSignatureArgs (
288+ newRegion->getArgumentTypes (), result))) {
289+ return rewriter.notifyMatchFailure (
290+ op, " signature argument type conversion failed" );
291+ }
292+ rewriter.applySignatureConversion (newRegion, result);
293+ }
282294
283- for (Region ®ion : op->getRegions ()) {
284- Region *newRegion = state.addRegion ();
285- rewriter.inlineRegionBefore (region, *newRegion, newRegion->begin ());
286- TypeConverter::SignatureConversion result (newRegion->getNumArguments ());
287- (void )getTypeConverter ()->convertSignatureArgs (
288- newRegion->getArgumentTypes (), result);
289- rewriter.applySignatureConversion (newRegion, result);
295+ Operation *newOp = rewriter.create (state);
296+ rewriter.replaceOp (op, newOp->getResults ());
297+ return success ();
290298 }
291-
292- Operation *newOp = rewriter.create (state);
293- rewriter.replaceOp (op, newOp->getResults ());
294- return success ();
295- }
299+ };
300+ } // namespace
296301
297302void spirv::populateMemorySpaceToStorageClassPatterns (
298303 spirv::MemorySpaceToStorageClassConverter &typeConverter,
@@ -308,58 +313,53 @@ namespace {
308313class MapMemRefStorageClassPass final
309314 : public impl::MapMemRefStorageClassBase<MapMemRefStorageClassPass> {
310315public:
311- explicit MapMemRefStorageClassPass () {
312- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
313- }
316+ MapMemRefStorageClassPass () = default ;
317+
314318 explicit MapMemRefStorageClassPass (
315319 const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)
316320 : memorySpaceMap(memorySpaceMap) {}
317321
318- LogicalResult initializeOptions (StringRef options) override ;
319-
320- void runOnOperation () override ;
321-
322- private:
323- spirv::MemorySpaceToStorageClassMap memorySpaceMap;
324- };
325- } // namespace
322+ LogicalResult initializeOptions (StringRef options) override {
323+ if (failed (Pass::initializeOptions (options)))
324+ return failure ();
326325
327- LogicalResult MapMemRefStorageClassPass::initializeOptions (StringRef options) {
328- if (failed (Pass::initializeOptions (options)))
329- return failure ();
326+ if (clientAPI == " opencl" )
327+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
328+ else if (clientAPI != " vulkan" )
329+ return failure ();
330330
331- if (clientAPI == " opencl" ) {
332- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
331+ return success ();
333332 }
334333
335- if (clientAPI != " vulkan" && clientAPI != " opencl" )
336- return failure ();
334+ void runOnOperation () override {
335+ MLIRContext *context = &getContext ();
336+ Operation *op = getOperation ();
337+
338+ if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv (op)) {
339+ spirv::TargetEnv targetEnv (attr);
340+ if (targetEnv.allows (spirv::Capability::Kernel)) {
341+ memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
342+ } else if (targetEnv.allows (spirv::Capability::Shader)) {
343+ memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
344+ }
345+ }
337346
338- return success ();
339- }
347+ std::unique_ptr<ConversionTarget> target =
348+ spirv::getMemorySpaceToStorageClassTarget (*context);
349+ spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
340350
341- void MapMemRefStorageClassPass::runOnOperation () {
342- MLIRContext *context = &getContext ();
343- Operation *op = getOperation ();
351+ RewritePatternSet patterns (context);
352+ spirv::populateMemorySpaceToStorageClassPatterns (converter, patterns);
344353
345- if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv (op)) {
346- spirv::TargetEnv targetEnv (attr);
347- if (targetEnv.allows (spirv::Capability::Kernel)) {
348- memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass;
349- } else if (targetEnv.allows (spirv::Capability::Shader)) {
350- memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass;
351- }
354+ if (failed (applyFullConversion (op, *target, std::move (patterns))))
355+ return signalPassFailure ();
352356 }
353357
354- auto target = spirv::getMemorySpaceToStorageClassTarget (*context);
355- spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
356-
357- RewritePatternSet patterns (context);
358- spirv::populateMemorySpaceToStorageClassPatterns (converter, patterns);
359-
360- if (failed (applyFullConversion (op, *target, std::move (patterns))))
361- return signalPassFailure ();
362- }
358+ private:
359+ spirv::MemorySpaceToStorageClassMap memorySpaceMap =
360+ spirv::mapMemorySpaceToVulkanStorageClass;
361+ };
362+ } // namespace
363363
364364std::unique_ptr<OperationPass<>> mlir::createMapMemRefStorageClassPass () {
365365 return std::make_unique<MapMemRefStorageClassPass>();
0 commit comments