4 | 4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | 6 |
| 7 | +#include "compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h" |
7 | 8 | #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h" |
8 | 9 | #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h" |
9 | 10 | #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" |
@@ -155,230 +156,6 @@ static NestedLayoutAttr createNestedLayout( |
155 | 156 | return layoutAttr; |
156 | 157 | } |
157 | 158 |
158 | | -static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface, |
159 | | - IREE::VectorExt::VectorLayoutInterface, |
160 | | - IREE::VectorExt::VectorLayoutInterface>> |
161 | | -getContractionLayout(IREE::GPU::MMAScheduleAttr schedule, |
162 | | - VectorContractOpInfo &opInfo, |
163 | | - linalg::LinalgOp contractOp) { |
164 | | - LLVM_DEBUG({ |
165 | | - llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n"; |
166 | | - llvm::dbgs() << "For schedule: " << schedule << "\n"; |
167 | | - }); |
168 | | - |
169 | | - int64_t rank = contractOp.getIteratorTypesArray().size(); |
170 | | - auto mmaAttr = |
171 | | - llvm::cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic()); |
172 | | - MLIRContext *context = schedule.getContext(); |
173 | | - |
174 | | - SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges(); |
175 | | - if (llvm::any_of(bounds, |
176 | | - [](int64_t x) { return x == ShapedType::kDynamic; })) { |
177 | | - return failure(); |
178 | | - } |
179 | | - |
180 | | - if (!llvm::all_of(opInfo.getBatchDims(), |
181 | | - [&bounds](int64_t dim) { return bounds[dim] == 1; })) { |
182 | | - LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; }); |
183 | | - return failure(); |
184 | | - } |
185 | | - |
186 | | - // Get the concrete nested layout for each matrix. Note that the struct |
187 | | - // MMASingleSubgroupLayout contains the partial layout for the |
188 | | - // canonical (M, K) x (K, N) -> (M, N) matmul form, while the specific |
189 | | - // contract op we are looking at right now may not be exactly in that form. |
190 | | - // So here we need to permute/transpose the canonical layout to match |
191 | | - // the concrete contract op. |
192 | | - |
193 | | - // Note that no matter how we permute/transpose the input contraction |
194 | | - // problem, the way we view the hardware warps remains the same--that is, |
195 | | - // from the hardware's perspective, a single warp has the same warp ID no |
196 | | - // matter what part of the contraction it works on. Similarly here, we are |
197 | | - // delinearizing the linearized GPU hardware lane ID into an n-D concatenated |
198 | | - // logical warp+thread ID using the subgroup/thread basis, so the subgroup |
199 | | - // basis should remain the same for all of the A/B/C matrices. |
200 | | - |
201 | | - auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape(); |
202 | | - |
203 | | - SmallVector<int64_t, 2> subgroupMBasis; |
204 | | - SmallVector<int64_t, 2> batchMSizes; |
205 | | - int64_t currMCount = schedule.getSubgroupMCount(); |
206 | | - |
207 | | - auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize, |
208 | | - int64_t minDimSize) -> std::pair<int64_t, int64_t> { |
209 | | - int64_t dividableDim = dimSize / minDimSize; |
210 | | - int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim); |
211 | | - dividableDim /= subgroupsUsed; |
212 | | - int64_t batchesUsed = dividableDim; |
213 | | - return {subgroupsUsed, batchesUsed}; |
214 | | - }; |
215 | | - |
216 | | - // Greedily break up the M subgroup and batch counts along the "M" iteration |
217 | | - // bounds. We distribute as many residual subgroups as possible per M dim, |
218 | | - // and then divide the remainder along batch dims. The innermost M dim is |
219 | | - // always the one used for the intrinsic, meaning that for a valid schedule, |
220 | | - // the computed batch counts and subgroup basis will satisfy |
221 | | - // totalMSize / intrinsicM = product(batchMSizes) * product(subgroupMBasis). |
222 | | - for (auto dim : opInfo.getMDims()) { |
223 | | - // Get the number of subgroups and batches used for this dimension based |
224 | | - // on the intrinsic size and the bound size. |
225 | | - int64_t subgroupsUsed, batchesUsed; |
226 | | - if (dim == opInfo.getMDims().back()) { |
227 | | - std::tie(subgroupsUsed, batchesUsed) = |
228 | | - divideGreedily(currMCount, bounds[dim], intrinsicM); |
229 | | - } else { |
230 | | - std::tie(subgroupsUsed, batchesUsed) = |
231 | | - divideGreedily(currMCount, bounds[dim], 1); |
232 | | - } |
233 | | - subgroupMBasis.push_back(subgroupsUsed); |
234 | | - batchMSizes.push_back(batchesUsed); |
235 | | - // Update available subgroup count. |
236 | | - currMCount /= subgroupsUsed; |
237 | | - } |
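// Illustrative trace (hypothetical shapes, for illustration only): with M
// iteration bounds {2, 64}, intrinsicM = 16, and subgroupMCount = 4, the
// outer M dim gets divideGreedily(4, 2, /*minDimSize=*/1) = {2, 1}, leaving
// 2 subgroups, and the innermost M dim gets
// divideGreedily(2, 64, /*minDimSize=*/16) = {2, 2}. The result is
// subgroupMBasis = {2, 2} and batchMSizes = {1, 2}, which satisfies
// totalMSize / intrinsicM = 128 / 16 = 8 = (1 * 2) * (2 * 2).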
238 | | - |
239 | | - SmallVector<int64_t, 2> subgroupNBasis; |
240 | | - SmallVector<int64_t, 2> batchNSizes; |
241 | | - int64_t currNCount = schedule.getSubgroupNCount(); |
242 | | - |
243 | | - // Do the same for N dims. |
244 | | - for (auto dim : opInfo.getNDims()) { |
245 | | - // Get the number of subgroups and batches used for this dimension based |
246 | | - // on the intrinsic size and the bound size. |
247 | | - int64_t subgroupsUsed, batchesUsed; |
248 | | - if (dim == opInfo.getNDims().back()) { |
249 | | - std::tie(subgroupsUsed, batchesUsed) = |
250 | | - divideGreedily(currNCount, bounds[dim], intrinsicN); |
251 | | - } else { |
252 | | - std::tie(subgroupsUsed, batchesUsed) = |
253 | | - divideGreedily(currNCount, bounds[dim], 1); |
254 | | - } |
255 | | - subgroupNBasis.push_back(subgroupsUsed); |
256 | | - batchNSizes.push_back(batchesUsed); |
257 | | - // Update available subgroup count. |
258 | | - currNCount /= subgroupsUsed; |
259 | | - } |
260 | | - |
261 | | - SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size()); |
262 | | - SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size()); |
263 | | - |
264 | | - auto mDimVec = opInfo.getMDims(); |
265 | | - llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end()); |
266 | | - auto nDimVec = opInfo.getNDims(); |
267 | | - llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end()); |
268 | | - // Because we currently require all batch dimensions to be unit, the |
269 | | - // subgroup basis can be constructed from the M and N bases. To keep things |
270 | | - // simple, the current heuristic is to distribute the loop dimensions from |
271 | | - // outer to inner. |
272 | | - int64_t currStride = 1; |
273 | | - int64_t currM = subgroupMStrides.size() - 1; |
274 | | - int64_t currN = subgroupNStrides.size() - 1; |
275 | | - for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) { |
276 | | - if (mDims.contains(dim)) { |
277 | | - subgroupMStrides[currM] = currStride; |
278 | | - currStride *= subgroupMBasis[currM]; |
279 | | - currM--; |
280 | | - continue; |
281 | | - } |
282 | | - |
283 | | - if (nDims.contains(dim)) { |
284 | | - subgroupNStrides[currN] = currStride; |
285 | | - currStride *= subgroupNBasis[currN]; |
286 | | - currN--; |
287 | | - continue; |
288 | | - } |
289 | | - } |
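// Illustrative trace (hypothetical 2-D case): with rank = 2, mDims = {0},
// nDims = {1}, subgroupMBasis = {2}, and subgroupNBasis = {4}, the reverse
// walk assigns subgroupNStrides = {1} first and subgroupMStrides = {4} next,
// so a flat subgroup ID sg in [0, 8) delinearizes to sgN = sg % 4 and
// sgM = sg / 4; the outer loop dimension gets the larger stride.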
290 | | - |
291 | | - // C matrix layout |
292 | | - auto [m, n] = opInfo.getResultMNIndex(); |
293 | | - int64_t cRank = opInfo.getCRank(); |
294 | | - |
295 | | - // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and |
296 | | - // cNDims are the M and N dimensions of the C matrix in the order they are |
297 | | - // iterated over in the contraction. |
298 | | - SmallVector<int64_t> cMDims = opInfo.outMDims; |
299 | | - SmallVector<int64_t> cNDims = opInfo.outNDims; |
300 | | - SmallVector<int64_t> cBatchSizes(cRank, 1); |
301 | | - SmallVector<int64_t> cSubgroupSizes(cRank, 1); |
302 | | - SmallVector<int64_t> cSubgroupStrides(cRank, 0); |
303 | | - for (auto [i, dim] : llvm::enumerate(cMDims)) { |
304 | | - cBatchSizes[dim] = batchMSizes[i]; |
305 | | - cSubgroupSizes[dim] = subgroupMBasis[i]; |
306 | | - cSubgroupStrides[dim] = subgroupMStrides[i]; |
307 | | - } |
308 | | - for (auto [i, dim] : llvm::enumerate(cNDims)) { |
309 | | - cBatchSizes[dim] = batchNSizes[i]; |
310 | | - cSubgroupSizes[dim] = subgroupNBasis[i]; |
311 | | - cSubgroupStrides[dim] = subgroupNStrides[i]; |
312 | | - } |
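// Illustrative trace (continuing the hypothetical 2-D case): with cRank = 2,
// cMDims = {0}, cNDims = {1}, batchMSizes = {2}, subgroupMBasis = {2},
// subgroupMStrides = {4}, batchNSizes = {1}, subgroupNBasis = {4}, and
// subgroupNStrides = {1}, the loops above produce cBatchSizes = {2, 1},
// cSubgroupSizes = {2, 4}, and cSubgroupStrides = {4, 1}.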
313 | | - |
314 | | - IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout( |
315 | | - context, cRank, m, n, |
316 | | - /*subgroupCount=*/cSubgroupSizes, |
317 | | - /*subgroupStrides=*/cSubgroupStrides, |
318 | | - /*batchCount=*/cBatchSizes, |
319 | | - getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc)); |
320 | | - LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; }); |
321 | | - |
322 | | - // A matrix layout |
323 | | - auto [afm, bfn] = opInfo.getOperandMNIndex(); |
324 | | - auto [afk, bfk] = opInfo.getOperandKIndex(); |
325 | | - |
326 | | - int64_t aRank = opInfo.getARank(); |
327 | | - |
328 | | - SmallVector<int64_t> aMDims = opInfo.lhsMDims; |
329 | | - SmallVector<int64_t> aBatchSizes(aRank, 1); |
330 | | - SmallVector<int64_t> aSubgroupSizes(aRank, 1); |
331 | | - SmallVector<int64_t> aSubgroupStrides(aRank, 0); |
332 | | - for (auto [i, dim] : llvm::enumerate(aMDims)) { |
333 | | - aBatchSizes[dim] = batchMSizes[i]; |
334 | | - aSubgroupSizes[dim] = subgroupMBasis[i]; |
335 | | - aSubgroupStrides[dim] = subgroupMStrides[i]; |
336 | | - } |
337 | | - for (auto [kDim, lhsKDim] : |
338 | | - llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) { |
339 | | - aBatchSizes[lhsKDim] = bounds[kDim]; |
340 | | - } |
341 | | - aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK; |
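// Illustrative trace (hypothetical shapes): with K iteration bounds {4, 32}
// and intrinsicK = 16, the outer K dim keeps its full extent as a batch
// count (4) while the innermost K dim gets 32 / 16 = 2 batches, so each
// subgroup steps through 4 * 2 = 8 intrinsic-sized K tiles; K is never
// distributed across subgroups (aSubgroupSizes stays 1 along K).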
342 | | - |
343 | | - IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout( |
344 | | - context, aRank, afm, afk, |
345 | | - /*subgroupCount=*/aSubgroupSizes, |
346 | | - /*subgroupStrides=*/aSubgroupStrides, |
347 | | - /*batchCount=*/aBatchSizes, |
348 | | - getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs)); |
349 | | - LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; }); |
350 | | - |
351 | | - int64_t bRank = opInfo.getBRank(); |
352 | | - |
353 | | - SmallVector<int64_t> bNDims = opInfo.rhsNDims; |
354 | | - SmallVector<int64_t> bBatchSizes(bRank, 1); |
355 | | - SmallVector<int64_t> bSubgroupSizes(bRank, 1); |
356 | | - SmallVector<int64_t> bSubgroupStrides(bRank, 0); |
357 | | - for (auto [i, dim] : llvm::enumerate(bNDims)) { |
358 | | - bBatchSizes[dim] = batchNSizes[i]; |
359 | | - bSubgroupSizes[dim] = subgroupNBasis[i]; |
360 | | - bSubgroupStrides[dim] = subgroupNStrides[i]; |
361 | | - } |
362 | | - for (auto [kDim, rhsKDim] : |
363 | | - llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) { |
364 | | - bBatchSizes[rhsKDim] = bounds[kDim]; |
365 | | - } |
366 | | - bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK; |
367 | | - |
368 | | - IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout( |
369 | | - context, bRank, bfk, bfn, |
370 | | - /*subgroupCount=*/bSubgroupSizes, |
371 | | - /*subgroupStrides=*/bSubgroupStrides, |
372 | | - /*batchCount=*/bBatchSizes, |
373 | | - getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs)); |
374 | | - LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; }); |
375 | | - |
376 | | - std::tuple<VectorLayoutInterface, VectorLayoutInterface, |
377 | | - VectorLayoutInterface> |
378 | | - result = {aLayout, bLayout, cLayout}; |
379 | | - return result; |
380 | | -} |
381 | | - |
382 | 159 | static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule, |
383 | 160 | SmallVector<bool> promotedOperands, |
384 | 161 | RewriterBase &rewriter, |