1010#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1111#include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1212#include " iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
13- #include " iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
1413#include " iree/compiler/Codegen/Utils/GPUUtils.h"
1514#include " iree/compiler/Codegen/Utils/Utils.h"
1615#include " iree/compiler/Dialect/HAL/IR/HALOps.h"
17- #include " llvm/ADT/APInt.h"
1816#include " llvm/ADT/STLExtras.h"
1917#include " mlir/Analysis/SliceAnalysis.h"
2018#include " mlir/Dialect/Linalg/IR/Linalg.h"
@@ -62,210 +60,6 @@ static bool isMatvecLike(linalg::LinalgOp linalgOp) {
6260 return true ;
6361}
6462
65- static LogicalResult
66- setWarpReductionConfig (IREE::GPU::TargetAttr target,
67- mlir::FunctionOpInterface entryPoint,
68- linalg::LinalgOp op) {
69- if (!target.supportsSubgroupShuffle ())
70- return failure ();
71-
72- SmallVector<unsigned > parallelDims;
73- SmallVector<unsigned > reductionDims;
74- op.getParallelDims (parallelDims);
75- op.getReductionDims (reductionDims);
76-
77- SmallVector<int64_t > bounds = op.getStaticLoopRanges ();
78- int64_t numParallelDims = op.getNumParallelLoops ();
79-
80- if (reductionDims.empty ())
81- return failure ();
82-
83- // Make sure reduction dimensions are static and innermost ones.
84- int64_t numDynamicReductionDims = 0 ;
85- for (unsigned dim : reductionDims) {
86- if (ShapedType::isDynamic (bounds[dim])) {
87- numDynamicReductionDims++;
88- }
89- if (dim < numParallelDims) {
90- return failure ();
91- }
92- }
93-
94- // Distribution of multi-dim masked writes currently aren't fully supported.
95- if (numDynamicReductionDims > 1 ) {
96- return failure ();
97- }
98-
99- if (op.getRegionOutputArgs ().size () != 1 )
100- return failure ();
101-
102- // Only support projected permutation, this could be extended to projected
103- // permutated with broadcast.
104- if (llvm::any_of (op.getDpsInputOperands (), [&](OpOperand *input) {
105- return !op.getMatchingIndexingMap (input).isProjectedPermutation ();
106- }))
107- return failure ();
108-
109- bool foundSingleReductionOutput = false ;
110- for (auto [index, initOpOperand] : llvm::enumerate (op.getDpsInitsMutable ())) {
111- // Only single combiner operations are supported for now.
112- SmallVector<Operation *> combinerOps;
113- if (matchReduction (op.getRegionOutputArgs (), index, combinerOps) &&
114- combinerOps.size () == 1 ) {
115- if (foundSingleReductionOutput)
116- return failure ();
117- foundSingleReductionOutput = true ;
118- continue ;
119- }
120- if (!op.getMatchingIndexingMap (&initOpOperand).isIdentity ())
121- return failure ();
122- }
123- if (!foundSingleReductionOutput)
124- return failure ();
125-
126- // Tile all the parallel dimension to 1.
127- SmallVector<unsigned > partitionedLoops =
128- cast<PartitionableLoopsInterface>(op.getOperation ())
129- .getPartitionableLoops (kNumMaxParallelDims );
130- size_t numLoops = partitionedLoops.empty () ? 0 : partitionedLoops.back () + 1 ;
131- SmallVector<int64_t > workgroupTileSizes (numLoops, 1 );
132-
133- // Without any bounds on dynamic reduction dims, we need specialization to
134- // get peak performance. For now, just use the warp size.
135- if (numDynamicReductionDims) {
136- SmallVector<int64_t > reductionTileSizes (op.getNumLoops (), 0 );
137- int64_t preferredSubgroupSize = target.getPreferredSubgroupSize ();
138- reductionTileSizes[reductionDims[0 ]] = preferredSubgroupSize;
139- TileSizesListType tileSizes;
140- tileSizes.emplace_back (std::move (workgroupTileSizes)); // Workgroup level
141- tileSizes.emplace_back (std::move (reductionTileSizes)); // Reduction level
142- std::array<int64_t , 3 > workgroupSize = {preferredSubgroupSize, 1 , 1 };
143- if (failed (setOpConfigAndEntryPointFnTranslation (
144- entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
145- workgroupSize))) {
146- return failure ();
147- }
148- return success ();
149- }
150-
151- int64_t reductionSize = 1 ;
152- for (int64_t dim : reductionDims)
153- reductionSize *= bounds[dim];
154-
155- int64_t subgroupSize = 0 ;
156- for (int s : target.getWgp ().getSubgroupSizeChoices ().asArrayRef ()) {
157- if (reductionSize % s == 0 ) {
158- subgroupSize = s;
159- break ;
160- }
161- }
162- if (subgroupSize == 0 )
163- return failure ();
164-
165- const Type elementType =
166- cast<ShapedType>(op.getDpsInitOperand (0 )->get ().getType ())
167- .getElementType ();
168- if (!elementType.isIntOrFloat ())
169- return failure ();
170- unsigned bitWidth = elementType.getIntOrFloatBitWidth ();
171- // Reduction distribution only supports 8/16/32 bit types now.
172- if (bitWidth != 32 && bitWidth != 16 && bitWidth != 8 )
173- return failure ();
174-
175- const unsigned largestLoadSizeInBits = 128 ;
176- unsigned vectorSize = largestLoadSizeInBits / bitWidth;
177- while ((reductionSize / vectorSize) % subgroupSize != 0 )
178- vectorSize /= 2 ;
179-
180- // Deduce the workgroup size we should use for reduction. Currently a
181- // workgroup processes all elements in reduction dimensions. Need to make sure
182- // the workgroup size we use can divide the total reduction size, and it's
183- // also within hardware limitations.
184- const int64_t maxWorkgroupSize = 1024 ;
185- int64_t groupSize = reductionSize / vectorSize;
186- if (groupSize > maxWorkgroupSize) {
187- groupSize = llvm::APIntOps::GreatestCommonDivisor (
188- {64 , uint64_t (groupSize)}, {64 , uint64_t (maxWorkgroupSize)})
189- .getZExtValue ();
190- }
191-
192- // Then we need to strike a balance--
193- // 1) parallel dimensions are distributed to workgroups. If there are many
194- // workgroups dispatched, we'd want to have each GPU core hosting multiple
195- // of them for occupancy.
196- // 2) we want each thread to read quite a few 128-bit vectors for better
197- // memory cache behavior.
198- // Both means we cannot use a too large workgroup size.
199-
200- std::optional<int64_t > parallelSize = 1 ;
201- for (int64_t dim : parallelDims) {
202- if (ShapedType::isDynamic (bounds[dim])) {
203- parallelSize = std::nullopt ;
204- break ;
205- }
206- *parallelSize *= bounds[dim];
207- }
208- // Total parallel size that can fill the GPU with enough workgorups.
209- // TODO: query from the target device; roughly 2x hardware compute unit.
210- const int parallelThreshold = 256 ;
211- // How many 128-bit vectors each thread should at least read.
212- const int targetVectorCount = 8 ;
213- while (parallelSize && *parallelSize > parallelThreshold &&
214- (groupSize / 2 ) % subgroupSize == 0 &&
215- reductionSize / (groupSize * vectorSize) < targetVectorCount) {
216- // Use less subgroups per workgroup..
217- groupSize /= 2 ;
218- // in order to host more workgroups per hardware compute unit.
219- *parallelSize /= 2 ;
220- }
221-
222- // Current warp reduction pattern is a two step butterfly warp reduce.
223- // First, do warp reductions along multiple subgroups.
224- // Second, reduce results from multiple subgroups using single warp reduce.
225- // The final warp reduce requires subgroup count <= subgroup size to work.
226- if ((groupSize / subgroupSize) > subgroupSize)
227- return failure ();
228-
229- // With just one subgroup per workgroup, make each subgroup do more work and
230- // process a few reductions (rows) along the last parallel dimension.
231- if (llvm::none_of (bounds, ShapedType::isDynamic) && isMatvecLike (op)) {
232- int64_t lastParallelBound = bounds[parallelDims.back ()];
233- int64_t numParallelReductions = 1 ;
234- const int64_t maxParallelFactor = groupSize / 4 ;
235- for (int64_t parallelFactor = 2 ;
236- (parallelFactor < maxParallelFactor) &&
237- (lastParallelBound % parallelFactor == 0 ) &&
238- (lastParallelBound > parallelFactor);
239- parallelFactor *= 2 ) {
240- numParallelReductions = parallelFactor;
241- }
242- workgroupTileSizes.back () = numParallelReductions;
243- }
244-
245- std::array<int64_t , 3 > workgroupSize = {groupSize, 1 , 1 };
246- SmallVector<int64_t > reductionTileSizes (op.getNumLoops (), 0 );
247- int64_t remainingGroupSize = groupSize;
248- for (int i = reductionDims.size () - 1 ; i >= 0 ; --i) {
249- int64_t dim = reductionDims[i];
250- int64_t bound = bounds[dim];
251- if (i == reductionDims.size () - 1 )
252- bound /= vectorSize;
253- APInt size = llvm::APIntOps::GreatestCommonDivisor (
254- {64 , uint64_t (remainingGroupSize)}, {64 , uint64_t (bound)});
255- reductionTileSizes[dim] = size.getSExtValue ();
256- if (i == reductionDims.size () - 1 )
257- reductionTileSizes[dim] *= vectorSize;
258- remainingGroupSize /= size.getSExtValue ();
259- }
260- TileSizesListType tileSizes;
261- tileSizes.emplace_back (std::move (workgroupTileSizes)); // Workgroup level
262- tileSizes.emplace_back (std::move (reductionTileSizes)); // Reduction level
263- return setOpConfigAndEntryPointFnTranslation (
264- entryPoint, op, tileSizes, CodeGenPipeline::LLVMGPUWarpReduction,
265- workgroupSize, subgroupSize);
266- return success ();
267- }
268-
26963// ===----------------------------------------------------------------------===//
27064// Root Configuration
27165// ===----------------------------------------------------------------------===//
@@ -287,9 +81,6 @@ static LogicalResult setRootConfig(IREE::GPU::TargetAttr target,
28781 target, entryPointFn, computeOp))) {
28882 return success ();
28983 }
290- if (succeeded (setWarpReductionConfig (target, entryPointFn, linalgOp))) {
291- return success ();
292- }
29384 // TODO: Add configurations for matmul here too.
29485 if (succeeded (IREE::GPU::setTileAndFuseLoweringConfig (target, entryPointFn,
29586 computeOp))) {
0 commit comments