1616#include " mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
1717#include " mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
1818#include " mlir/Dialect/Arith/Transforms/Passes.h"
19+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
1920#include " mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
2021#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
2122#include " mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
@@ -40,6 +41,35 @@ using namespace mlir;
4041
4142namespace {
4243
44+ // / Map memRef memory space to SPIR-V storage class.
45+ void mapToMemRef (Operation *op, spirv::TargetEnvAttr &targetAttr) {
46+ spirv::TargetEnv targetEnv (targetAttr);
47+ bool targetEnvSupportsKernelCapability =
48+ targetEnv.allows (spirv::Capability::Kernel);
49+ spirv::MemorySpaceToStorageClassMap memorySpaceMap =
50+ targetEnvSupportsKernelCapability
51+ ? spirv::mapMemorySpaceToOpenCLStorageClass
52+ : spirv::mapMemorySpaceToVulkanStorageClass;
53+ spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
54+ spirv::convertMemRefTypesAndAttrs (op, converter);
55+ }
56+
57+ // / Populate patterns for each dialect.
58+ void populateConvertToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
59+ ScfToSPIRVContext &scfToSPIRVContext,
60+ RewritePatternSet &patterns) {
61+ arith::populateCeilFloorDivExpandOpsPatterns (patterns);
62+ arith::populateArithToSPIRVPatterns (typeConverter, patterns);
63+ populateBuiltinFuncToSPIRVPatterns (typeConverter, patterns);
64+ populateFuncToSPIRVPatterns (typeConverter, patterns);
65+ populateGPUToSPIRVPatterns (typeConverter, patterns);
66+ index::populateIndexToSPIRVPatterns (typeConverter, patterns);
67+ populateMemRefToSPIRVPatterns (typeConverter, patterns);
68+ populateVectorToSPIRVPatterns (typeConverter, patterns);
69+ populateSCFToSPIRVPatterns (typeConverter, scfToSPIRVContext, patterns);
70+ ub::populateUBToSPIRVConversionPatterns (typeConverter, patterns);
71+ }
72+
4373// / A pass to perform the SPIR-V conversion.
4474struct ConvertToSPIRVPass final
4575 : impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
@@ -57,38 +87,46 @@ struct ConvertToSPIRVPass final
5787 if (runVectorUnrolling && failed (spirv::unrollVectorsInFuncBodies (op)))
5888 return signalPassFailure ();
5989
60- spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault (op);
61- std::unique_ptr<ConversionTarget> target =
62- SPIRVConversionTarget::get (targetAttr);
63- SPIRVTypeConverter typeConverter (targetAttr);
64- RewritePatternSet patterns (context);
65- ScfToSPIRVContext scfToSPIRVContext;
66-
67- // Map MemRef memory space to SPIR-V storage class.
68- spirv::TargetEnv targetEnv (targetAttr);
69- bool targetEnvSupportsKernelCapability =
70- targetEnv.allows (spirv::Capability::Kernel);
71- spirv::MemorySpaceToStorageClassMap memorySpaceMap =
72- targetEnvSupportsKernelCapability
73- ? spirv::mapMemorySpaceToOpenCLStorageClass
74- : spirv::mapMemorySpaceToVulkanStorageClass;
75- spirv::MemorySpaceToStorageClassConverter converter (memorySpaceMap);
76- spirv::convertMemRefTypesAndAttrs (op, converter);
90+ // Generic conversion.
91+ if (!convertGPUModules) {
92+ spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault (op);
93+ std::unique_ptr<ConversionTarget> target =
94+ SPIRVConversionTarget::get (targetAttr);
95+ SPIRVTypeConverter typeConverter (targetAttr);
96+ RewritePatternSet patterns (context);
97+ ScfToSPIRVContext scfToSPIRVContext;
98+ mapToMemRef (op, targetAttr);
99+ populateConvertToSPIRVPatterns (typeConverter, scfToSPIRVContext,
100+ patterns);
101+ if (failed (applyPartialConversion (op, *target, std::move (patterns))))
102+ return signalPassFailure ();
103+ return ;
104+ }
77105
78- // Populate patterns for each dialect.
79- arith::populateCeilFloorDivExpandOpsPatterns (patterns);
80- arith::populateArithToSPIRVPatterns (typeConverter, patterns);
81- populateBuiltinFuncToSPIRVPatterns (typeConverter, patterns);
82- populateFuncToSPIRVPatterns (typeConverter, patterns);
83- populateGPUToSPIRVPatterns (typeConverter, patterns);
84- index::populateIndexToSPIRVPatterns (typeConverter, patterns);
85- populateMemRefToSPIRVPatterns (typeConverter, patterns);
86- populateVectorToSPIRVPatterns (typeConverter, patterns);
87- populateSCFToSPIRVPatterns (typeConverter, scfToSPIRVContext, patterns);
88- ub::populateUBToSPIRVConversionPatterns (typeConverter, patterns);
89-
90- if (failed (applyPartialConversion (op, *target, std::move (patterns))))
91- return signalPassFailure ();
106+ // Clone each GPU kernel module for conversion, given that the GPU
107+ // launch op still needs the original GPU kernel module.
108+ SmallVector<Operation *, 1 > gpuModules;
109+ OpBuilder builder (context);
110+ op->walk ([&](gpu::GPUModuleOp gpuModule) {
111+ builder.setInsertionPoint (gpuModule);
112+ gpuModules.push_back (builder.clone (*gpuModule));
113+ });
114+ // Run conversion for each module independently as they can have
115+ // different TargetEnv attributes.
116+ for (Operation *gpuModule : gpuModules) {
117+ spirv::TargetEnvAttr targetAttr =
118+ spirv::lookupTargetEnvOrDefault (gpuModule);
119+ std::unique_ptr<ConversionTarget> target =
120+ SPIRVConversionTarget::get (targetAttr);
121+ SPIRVTypeConverter typeConverter (targetAttr);
122+ RewritePatternSet patterns (context);
123+ ScfToSPIRVContext scfToSPIRVContext;
124+ mapToMemRef (gpuModule, targetAttr);
125+ populateConvertToSPIRVPatterns (typeConverter, scfToSPIRVContext,
126+ patterns);
127+ if (failed (applyFullConversion (gpuModule, *target, std::move (patterns))))
128+ return signalPassFailure ();
129+ }
92130 }
93131};
94132
0 commit comments