Commit 07203e6
[NFC] Expose internal LLVMGPU APIs for vector_distribute (iree-org#21161)
This allows external systems to make use of IREE's vector_distribute functionality in a composable manner.

Signed-off-by: Nicolas Vasilache <[email protected]>
1 parent 3beeb26 commit 07203e6
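
The diffs below take ContractionVectorLayoutOptions, TransformVectorLayoutOptions, and getContractionLayout out of file-local scope so other code can reach them. A minimal sketch of what such external use could look like, assuming ContractionVectorLayoutOptions is now declared in Utils/LLVMGPUUtils.h (as the new includes below suggest) and that the distribution driver is distributeVectorOps from Codegen/Common/GPU/GPUVectorDistribution.h; check those headers before relying on the exact signatures.

#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"

namespace mlir::iree_compiler {

// Drive IREE's vector distribution from an external pass or tool, reusing the
// same options/pattern set that LLVMGPUVectorDistributePass builds internally.
static LogicalResult distributeExternally(Operation *root, Value laneId,
                                          int64_t subgroupSize) {
  ContractionVectorLayoutOptions options(root, laneId, subgroupSize);
  RewritePatternSet &patterns = options.getPatterns();
  // Caller-specific patterns could be appended to `patterns` here before
  // running the driver; that is the "composable" part.
  return distributeVectorOps(root, patterns, options);
}

} // namespace mlir::iree_compiler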

9 files changed: +410 additions, -266 deletions

compiler/src/iree/compiler/Codegen/LLVMGPU/InternalAPI.h

Whitespace-only changes.

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp

Lines changed: 1 addition & 224 deletions

@@ -4,6 +4,7 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+#include "compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
@@ -155,230 +156,6 @@ static NestedLayoutAttr createNestedLayout(
   return layoutAttr;
 }

-static FailureOr<std::tuple<IREE::VectorExt::VectorLayoutInterface,
-                            IREE::VectorExt::VectorLayoutInterface,
-                            IREE::VectorExt::VectorLayoutInterface>>
-getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
-                     VectorContractOpInfo &opInfo,
-                     linalg::LinalgOp contractOp) {
-  LLVM_DEBUG({
-    llvm::dbgs() << "Getting mma layouts for:\n" << contractOp << "\n";
-    llvm::dbgs() << "For schedule: " << schedule << "\n";
-  });
-
-  int64_t rank = contractOp.getIteratorTypesArray().size();
-  auto mmaAttr =
-      llvm::cast<IREE::GPU::MmaInterfaceAttr>(schedule.getIntrinsic());
-  MLIRContext *context = schedule.getContext();
-
-  SmallVector<int64_t> bounds = contractOp.getStaticLoopRanges();
-  if (llvm::any_of(bounds,
-                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
-    return failure();
-  }
-
-  if (!llvm::all_of(opInfo.getBatchDims(),
-                    [&bounds](int64_t dim) { return bounds[dim] == 1; })) {
-    LLVM_DEBUG({ llvm::dbgs() << "non-unit batch dimension\n"; });
-    return failure();
-  }
-
-  // Get the concrete nested layout for each matrix. Note that the struct
-  // MMASingleSubgroupLayout contains the partial layout for the
-  // canonical (M, K) x (K, N) -> (M, N) matmul form; while the specific
-  // contract op we are looking at right now may not be exactly in that form.
-  // So here we need to permute/transpose the canonical layout to match with
-  // the concrete contract op.
-
-  // Note that no matter how we permute/transpose the input contraction
-  // problem, the way we view the hardware warps remain the same--that is,
-  // from the hardware's perspective, a single warp has the same warp ID no
-  // matter what part of the contraction it works on. Similarly here, we are
-  // delinearizing the linearized GPU hardware lane ID into a n-D concatenated
-  // logical warp+thread using the subgroup/thread basis, so the subgroup
-  // basis should remain the same for all A/B/C matrix.
-
-  auto [intrinsicM, intrinsicN, intrinsicK] = mmaAttr.getMNKShape();
-
-  SmallVector<int64_t, 2> subgroupMBasis;
-  SmallVector<int64_t, 2> batchMSizes;
-  int64_t currMCount = schedule.getSubgroupMCount();
-
-  auto divideGreedily = [](int64_t availableSubgroups, int64_t dimSize,
-                           int64_t minDimSize) -> std::pair<int64_t, int64_t> {
-    int64_t dividableDim = dimSize / minDimSize;
-    int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim);
-    dividableDim /= subgroupsUsed;
-    int64_t batchesUsed = dividableDim;
-    return {subgroupsUsed, batchesUsed};
-  };
-
-  // Greedily break up the M subgroup and batch counts along the "M" iteration
-  // bounds. We distribute as many residual subgroups as possible per M dim,
-  // and then divide the remaining along batch dims. The inner most M dim is
-  // always the one used for the intrinsic, meaning for a valid schedule, the
-  // computed batch counts and subgroup basis will satisfy totalMSize /
-  // intrinsicM = product(batchMSizes) * product(subgroupMBasis)
-  for (auto dim : opInfo.getMDims()) {
-    // Get the number of subgroups and batches used for this dimension based
-    // on the intrinsic size and the bound size.
-    int64_t subgroupsUsed, batchesUsed;
-    if (dim == opInfo.getMDims().back()) {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currMCount, bounds[dim], intrinsicM);
-    } else {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currMCount, bounds[dim], 1);
-    }
-    subgroupMBasis.push_back(subgroupsUsed);
-    batchMSizes.push_back(batchesUsed);
-    // Update available subgroup count.
-    currMCount /= subgroupsUsed;
-  }
-
-  SmallVector<int64_t, 2> subgroupNBasis;
-  SmallVector<int64_t, 2> batchNSizes;
-  int64_t currNCount = schedule.getSubgroupNCount();
-
-  // Do the same for N dims.
-  for (auto dim : opInfo.getNDims()) {
-    // Get the number of subgroups and batches used for this dimension based
-    // on the intrinsic size and the bound size.
-    int64_t subgroupsUsed, batchesUsed;
-    if (dim == opInfo.getNDims().back()) {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currNCount, bounds[dim], intrinsicN);
-    } else {
-      std::tie(subgroupsUsed, batchesUsed) =
-          divideGreedily(currNCount, bounds[dim], 1);
-    }
-    subgroupNBasis.push_back(subgroupsUsed);
-    batchNSizes.push_back(batchesUsed);
-    // Update available subgroup count.
-    currNCount /= subgroupsUsed;
-  }
-
-  SmallVector<int64_t> subgroupMStrides(subgroupMBasis.size());
-  SmallVector<int64_t> subgroupNStrides(subgroupNBasis.size());
-
-  auto mDimVec = opInfo.getMDims();
-  llvm::SmallDenseSet<int64_t> mDims(mDimVec.begin(), mDimVec.end());
-  auto nDimVec = opInfo.getNDims();
-  llvm::SmallDenseSet<int64_t> nDims(nDimVec.begin(), nDimVec.end());
-  // Because we currently require all batch dimensions to be unit, the
-  // subgroup basis can be constructed from the M and N bases. To keep things
-  // simple, the current heuristic is to distribute the loop dimensions from
-  // outer to inner.
-  int64_t currStride = 1;
-  int64_t currM = subgroupMStrides.size() - 1;
-  int64_t currN = subgroupNStrides.size() - 1;
-  for (int64_t dim : llvm::reverse(llvm::seq<int64_t>(rank))) {
-    if (mDims.contains(dim)) {
-      subgroupMStrides[currM] = currStride;
-      currStride *= subgroupMBasis[currM];
-      currM--;
-      continue;
-    }
-
-    if (nDims.contains(dim)) {
-      subgroupNStrides[currN] = currStride;
-      currStride *= subgroupNBasis[currN];
-      currN--;
-      continue;
-    }
-  }
-
-  // C matrix layout
-  auto [m, n] = opInfo.getResultMNIndex();
-  int64_t cRank = opInfo.getCRank();
-
-  // Get the M and N dims w.r.t. the dimensions of the C matrix. cMDims and
-  // cNDims are the M and N dimensions of the C matrix in the order they are
-  // iterated over in the contraction.
-  SmallVector<int64_t> cMDims = opInfo.outMDims;
-  SmallVector<int64_t> cNDims = opInfo.outNDims;
-  SmallVector<int64_t> cBatchSizes(cRank, 1);
-  SmallVector<int64_t> cSubgroupSizes(cRank, 1);
-  SmallVector<int64_t> cSubgroupStrides(cRank, 0);
-  for (auto [i, dim] : llvm::enumerate(cMDims)) {
-    cBatchSizes[dim] = batchMSizes[i];
-    cSubgroupSizes[dim] = subgroupMBasis[i];
-    cSubgroupStrides[dim] = subgroupMStrides[i];
-  }
-  for (auto [i, dim] : llvm::enumerate(cNDims)) {
-    cBatchSizes[dim] = batchNSizes[i];
-    cSubgroupSizes[dim] = subgroupNBasis[i];
-    cSubgroupStrides[dim] = subgroupNStrides[i];
-  }
-
-  IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout(
-      context, cRank, m, n,
-      /*subgroupCount=*/cSubgroupSizes,
-      /*subgroupStrides=*/cSubgroupStrides,
-      /*batchCount=*/cBatchSizes,
-      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
-  LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
-
-  // A matrix layout
-  auto [afm, bfn] = opInfo.getOperandMNIndex();
-  auto [afk, bfk] = opInfo.getOperandKIndex();
-
-  int64_t aRank = opInfo.getARank();
-
-  SmallVector<int64_t> aMDims = opInfo.lhsMDims;
-  SmallVector<int64_t> aBatchSizes(aRank, 1);
-  SmallVector<int64_t> aSubgroupSizes(aRank, 1);
-  SmallVector<int64_t> aSubgroupStrides(aRank, 0);
-  for (auto [i, dim] : llvm::enumerate(aMDims)) {
-    aBatchSizes[dim] = batchMSizes[i];
-    aSubgroupSizes[dim] = subgroupMBasis[i];
-    aSubgroupStrides[dim] = subgroupMStrides[i];
-  }
-  for (auto [kDim, lhsKDim] :
-       llvm::zip_equal(opInfo.getKDims(), opInfo.lhsKDim)) {
-    aBatchSizes[lhsKDim] = bounds[kDim];
-  }
-  aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
-  IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout(
-      context, aRank, afm, afk,
-      /*subgroupCount=*/aSubgroupSizes,
-      /*subgroupStrides=*/aSubgroupStrides,
-      /*batchCount=*/aBatchSizes,
-      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
-  LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
-
-  int64_t bRank = opInfo.getBRank();
-
-  SmallVector<int64_t> bNDims = opInfo.rhsNDims;
-  SmallVector<int64_t> bBatchSizes(bRank, 1);
-  SmallVector<int64_t> bSubgroupSizes(bRank, 1);
-  SmallVector<int64_t> bSubgroupStrides(bRank, 0);
-  for (auto [i, dim] : llvm::enumerate(bNDims)) {
-    bBatchSizes[dim] = batchNSizes[i];
-    bSubgroupSizes[dim] = subgroupNBasis[i];
-    bSubgroupStrides[dim] = subgroupNStrides[i];
-  }
-  for (auto [kDim, rhsKDim] :
-       llvm::zip_equal(opInfo.getKDims(), opInfo.rhsKDim)) {
-    bBatchSizes[rhsKDim] = bounds[kDim];
-  }
-  bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
-
-  IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout(
-      context, bRank, bfk, bfn,
-      /*subgroupCount=*/bSubgroupSizes,
-      /*subgroupStrides=*/bSubgroupStrides,
-      /*batchCount=*/bBatchSizes,
-      getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
-  LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
-
-  std::tuple<VectorLayoutInterface, VectorLayoutInterface,
-             VectorLayoutInterface>
-      result = {aLayout, bLayout, cLayout};
-  return result;
-}
-
 static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
                                           SmallVector<bool> promotedOperands,
                                           RewriterBase &rewriter,
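
The greedy split in the getContractionLayout function deleted above (presumably re-homed via the new LLVMGPUUtils.h include) is easiest to follow with concrete numbers. A self-contained sketch of the divideGreedily step, using illustrative values not taken from the commit (4 subgroups available on M, an M loop bound of 128, intrinsic M of 16); it checks the invariant spelled out in the comment, totalMSize / intrinsicM == product(subgroupMBasis) * product(batchMSizes).

#include <cassert>
#include <cstdint>
#include <numeric>
#include <utility>

// Mirror of the divideGreedily lambda from getContractionLayout.
static std::pair<int64_t, int64_t> divideGreedily(int64_t availableSubgroups,
                                                  int64_t dimSize,
                                                  int64_t minDimSize) {
  int64_t dividableDim = dimSize / minDimSize;                        // 128 / 16 = 8
  int64_t subgroupsUsed = std::gcd(availableSubgroups, dividableDim); // gcd(4, 8) = 4
  dividableDim /= subgroupsUsed;                                      // 8 / 4 = 2 batches
  return {subgroupsUsed, dividableDim};
}

int main() {
  auto [subgroups, batches] = divideGreedily(/*availableSubgroups=*/4,
                                             /*dimSize=*/128,
                                             /*minDimSize=*/16);
  // 4 subgroups * 2 batches * intrinsicM(16) == 128, i.e. the invariant holds.
  assert(subgroups == 4 && batches == 2);
  return 0;
}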

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp

Lines changed: 23 additions & 27 deletions

@@ -9,6 +9,7 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
+#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -26,35 +27,30 @@ namespace mlir::iree_compiler {
 #define GEN_PASS_DEF_LLVMGPUVECTORDISTRIBUTEPASS
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h.inc"

-namespace {
-
-class ContractionVectorLayoutOptions : public VectorLayoutOptions {
-public:
-  ContractionVectorLayoutOptions(Operation *root, Value laneId,
-                                 int64_t subgroupSize)
-      : VectorLayoutOptions(root), patterns(root->getContext()) {
-    populateGPUDistributionPatterns(patterns);
-    populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
-                                                  subgroupSize);
-    populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
+ContractionVectorLayoutOptions::ContractionVectorLayoutOptions(
+    Operation *root, Value laneId, int64_t subgroupSize)
+    : VectorLayoutOptions(root), patterns(root->getContext()) {
+  populateGPUDistributionPatterns(patterns);
+  populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize);
+  populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
+}
+
+RewritePatternSet &ContractionVectorLayoutOptions::getPatterns() {
+  return patterns;
+}
+
+VectorLayoutInterface
+ContractionVectorLayoutOptions::getDefaultLayout(VectorType type) const {
+  // We only allow a default layout for 0-d vectors for now.
+  if (type.getRank() > 0) {
+    return VectorLayoutInterface();
   }
+  ArrayRef<int64_t> empty = {};
+  return IREE::VectorExt::NestedLayoutAttr::get(
+      type.getContext(), empty, empty, empty, empty, empty, empty, empty);
+}

-  RewritePatternSet &getPatterns() { return patterns; }
-
-  VectorLayoutInterface getDefaultLayout(VectorType type) const override {
-    // We only allow a default layout for 0-d vectors for now.
-    if (type.getRank() > 0) {
-      return VectorLayoutInterface();
-    }
-    ArrayRef<int64_t> empty = {};
-    return IREE::VectorExt::NestedLayoutAttr::get(
-        type.getContext(), empty, empty, empty, empty, empty, empty, empty);
-  }
-
-private:
-  RewritePatternSet patterns;
-};
-
+namespace {
 struct LLVMGPUVectorDistributePass final
     : impl::LLVMGPUVectorDistributePassBase<LLVMGPUVectorDistributePass> {
   void getDependentDialects(DialectRegistry &registry) const override {
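
One behavioral detail worth noting from getDefaultLayout above: only rank-0 vectors receive a fallback layout (a NestedLayoutAttr with all-empty arrays); every other vector value must get its layout from an explicit anchor. A small sketch of that policy, assuming ContractionVectorLayoutOptions is declared in Utils/LLVMGPUUtils.h as suggested by the new include; the helper below is illustrative, not part of IREE.

#include <cassert>

#include "iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUUtils.h"
#include "mlir/IR/BuiltinTypes.h"

namespace mlir::iree_compiler {

static void checkDefaultLayoutPolicy(ContractionVectorLayoutOptions &options,
                                     MLIRContext *ctx) {
  auto f32 = Float32Type::get(ctx);
  // Rank-0 vector: gets the degenerate NestedLayoutAttr default.
  IREE::VectorExt::VectorLayoutInterface zeroD =
      options.getDefaultLayout(VectorType::get({}, f32));
  assert(zeroD && "0-d vectors have a default layout");
  // Rank-1 (or higher) vector: no default; a layout anchor is required.
  IREE::VectorExt::VectorLayoutInterface oneD =
      options.getDefaultLayout(VectorType::get({4}, f32));
  assert(!oneD && "non-0-d vectors have no default layout");
}

} // namespace mlir::iree_compiler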

compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp

Lines changed: 0 additions & 10 deletions

@@ -1458,16 +1458,6 @@ transform_dialect::PrefetchSharedMemoryCopiesOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }

-class TransformVectorLayoutOptions : public VectorLayoutOptions {
-public:
-  TransformVectorLayoutOptions(Operation *root, bool fullConversion)
-      : VectorLayoutOptions(root, fullConversion) {}
-
-  VectorLayoutInterface getDefaultLayout(VectorType type) const override {
-    return VectorLayoutInterface();
-  }
-};
-
 DiagnosedSilenceableFailure
 transform_dialect::AMDGPUDistributeVectorsOp::applyToOne(
     transform::TransformRewriter &rewriter, mlir::FunctionOpInterface target,

compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h

Lines changed: 11 additions & 0 deletions

@@ -7,6 +7,7 @@
 #ifndef IREE_COMPILER_CODEGEN_LLVMGPU_TRANSFORMEXTENSIONS_LLVMGPUEXTENSIONS_H_
 #define IREE_COMPILER_CODEGEN_LLVMGPU_TRANSFORMEXTENSIONS_LLVMGPUEXTENSIONS_H_

+#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
@@ -31,6 +32,16 @@ class WarpExecuteOnLane0Op;

 namespace mlir::iree_compiler {

+class TransformVectorLayoutOptions : public VectorLayoutOptions {
+public:
+  TransformVectorLayoutOptions(Operation *root, bool fullConversion)
+      : VectorLayoutOptions(root, fullConversion) {}
+
+  VectorLayoutInterface getDefaultLayout(VectorType type) const override {
+    return VectorLayoutInterface();
+  }
+};
+
 /// Registers Flow transformations that require IREE-specific information into
 /// the transform dialect.
 void registerTransformDialectLLVMGPUExtension(DialectRegistry &registry);
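
With TransformVectorLayoutOptions now visible to anything that includes LLVMGPUExtensions.h, an out-of-tree transform op can mirror what AMDGPUDistributeVectorsOp does in the .cpp above. A hedged sketch: the populate* helpers are the ones used in the diffs, distributeVectorOps and its exact signature are assumed from Codegen/Common/GPU/GPUVectorDistribution.h, and the fullConversion flag is purely illustrative.

#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.h"

namespace mlir::iree_compiler {

static LogicalResult distributeForTransformOp(Operation *root, Value laneId,
                                              int64_t subgroupSize,
                                              bool fullConversion) {
  // No default layouts here: values without explicit layout anchors are left
  // untouched (getDefaultLayout returns a null interface).
  TransformVectorLayoutOptions options(root, fullConversion);
  RewritePatternSet patterns(root->getContext());
  populateGPUDistributionPatterns(patterns);
  populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize);
  return distributeVectorOps(root, patterns, options);
}

} // namespace mlir::iree_compiler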
