Commit 5612a26

Lower linalg.copy to direct global load (#20568)
## Summary

This PR sets the foundation for using the `global_load_lds` instruction to load values from global to LDS memory. The pipeline is as follows:

* Only convert `linalg.copy` ops emitted by `GPUPromoteMatmulOperands`. When it sees fit, that pass attaches a different attribute (`#iree_gpu.use_global_load_dma`) to the `linalg.copy` to tag it along the pipeline.
* A tagged `linalg.copy` is not decomposed/tiled until bufferization.
* After distribution to threads and bufferization, the tagged `linalg.copy` is lowered to a sequence of code built around the subgroup-coalesced load op `iree_gpu.global_load_dma`.
* `iree_gpu.global_load_dma` is mapped to the `amdgpu.gather_to_lds` op, which in turn maps to the corresponding ROCDL op.
* The pass that pads workgroup allocations to reduce bank conflicts is disabled, because the destination workgroup memory has to be contiguous.

## Lowering `linalg.copy`

After bufferization and distribution to threads, the tagged `linalg.copy` still exists in the IR:

```mlir
linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
  ins(%subview_12 : memref<64x128xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>)
  outs(%alloc_4 : memref<64x128xi8, #gpu.address_space<workgroup>>)
```

Note that this `linalg.copy` is kept in the thread-level code. The op itself is then converted into a `for` loop in which each subgroup of threads loads a coalesced chunk of values. For example, assume there are N subgroups loading from `tensor<a x b x c>`:

* The `i`-th subgroup loads a subtensor of size `[a/N, b, c]`, so each slice is consecutive.
  * At this moment, assume row-major layout and only tile the outermost dim.
  * The reason we currently only deal with `linalg.copy` ops emitted by `GPUPromoteMatmulOperands` is that we know their destination is allocated contiguously.
  * TODO: expand to arbitrary memref slices.
* Given `gpu.subgroup_id` and `gpu.lane_id`, each thread calculates the consecutive data chunk that its subgroup is responsible for loading:
  * The chunk indices are the delinearized indices of the input tensor, ranging from
    * `affine.delinearize_index[gpu.subgroup_id * (num_elems_of(tensor) / num_subgroups)]` to
    * `affine.delinearize_index[(gpu.subgroup_id + 1) * (num_elems_of(tensor) / num_subgroups) - 1]`.
  * Assume each subgroup loads `n` values starting at linearized index `N_f`; then the thread with lane id `i` (0-based) loads, for `iter = 0` to `n / subgroup_size - 1`, the element at linearized index `N_f + subgroup_size * iter + i` (see the worked numbers right after this list).
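As a quick sanity check of the index formula, using the numbers from the example below: a `64x128xi8` copy has `64 * 128 = 8192` elements; a 256-thread workgroup with `subgroup_size = 64` has 4 subgroups, so each subgroup covers `8192 / 4 = 2048` consecutive elements and each of its 64 lanes performs `2048 / 64 = 32` iterations. In iteration `iter`, lane `i` of subgroup `s` then reads linearized index `s * 2048 + iter * 64 + i`, which is exactly the affine map that appears in the generated loop.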
Then it is converted to something like the following (in this example, assume `workgroup size = 256`, `subgroup_size = 64`, loading `64x128xi8`):

```mlir
scf.for %indvar = %c0 to %c32 step %c1 {
  // thread-specific gathering address from the global address
  %17 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 2048 + s2 * 64)>()[%lane_id, %subgroup_id, %indvar]
  %18:2 = affine.delinearize_index %17 into (128, 64) : index, index
  // this iteration's base storing index
  %19 = affine.apply affine_map<()[s0, s1] -> (s0 * 2048 + s1 * 64)>()[%subgroup_id, %indvar]
  %20:2 = affine.delinearize_index %19 into (128, 64) : index, index
  iree_gpu.global_load_dma %subview_13[%18#0, %18#1] -> %alloc_5[%20#0, %20#1] : memref<128x64xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>> -> memref<128x64xi8, #gpu.address_space<workgroup>>
}
// if there are residual elements (subgroup_copy_region_size % subgroup_size != 0), copy residual elements here
gpu.barrier
```

## Dependent PRs

* Design doc: https://hackmd.io/N0RitxPzT9GPhM0jEPtOCg?view
* Upstream changes required:
  * llvm/llvm-project#133498
  * llvm/llvm-project#136405
  * llvm/llvm-project#137671
  * llvm/llvm-project#137425
* #20800 (review)

---------

Signed-off-by: Alan Li <[email protected]>
1 parent bb906c1 · commit 5612a26

26 files changed: +672, -10 lines

compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel

Lines changed: 2 additions & 0 deletions
```diff
@@ -74,6 +74,7 @@ iree_compiler_cc_library(
         "GPUGeneralizeNamedOps.cpp",
         "GPUGreedilyDistributeToThreads.cpp",
         "GPUInferMemorySpace.cpp",
+        "GPULowerToGlobalLoads.cpp",
         "GPULowerToUKernels.cpp",
         "GPUMultiBuffering.cpp",
         "GPUNestedLayoutDistributionPatterns.cpp",
@@ -146,6 +147,7 @@ iree_compiler_cc_library(
         "@llvm-project//mlir:LoopLikeInterface",
         "@llvm-project//mlir:MemRefDialect",
         "@llvm-project//mlir:MemRefTransforms",
+        "@llvm-project//mlir:MemRefUtils",
         "@llvm-project//mlir:NVGPUDialect",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Rewrite",
```

compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
```diff
@@ -67,6 +67,7 @@ iree_cc_library(
     "GPUGeneralizeNamedOps.cpp"
     "GPUGreedilyDistributeToThreads.cpp"
     "GPUInferMemorySpace.cpp"
+    "GPULowerToGlobalLoads.cpp"
     "GPULowerToUKernels.cpp"
     "GPUMultiBuffering.cpp"
     "GPUNestedLayoutDistributionPatterns.cpp"
@@ -116,6 +117,7 @@ iree_cc_library(
     MLIRLoopLikeInterface
     MLIRMemRefDialect
     MLIRMemRefTransforms
+    MLIRMemRefUtils
     MLIRNVGPUDialect
     MLIRPass
     MLIRRewrite
```

compiler/src/iree/compiler/Codegen/Common/GPU/GPUGreedilyDistributeToThreads.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
@@ -135,6 +136,10 @@ static void processRegion(RewriterBase &rewriter, Region *region) {
 
     // If an op implements the tiling interface, try to greedily tile + fuse.
     if (auto tilableOp = dyn_cast<TilingInterface>(op)) {
+      // Do not distribute to threads if an op wants to use DMA.
+      if (auto useDMAConfig =
+              getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op))
+        continue;
      tileToThreads(rewriter, tilableOp);
      continue;
    }
```

compiler/src/iree/compiler/Codegen/Common/GPU/GPUInferMemorySpace.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -38,6 +39,11 @@ bool isDefinitelyShared(bufferization::AllocTensorOp alloc) {
   // thread distributed `scf.forall` op. All other shared allocations are
   // expected to be properly indicated in advance.
   for (auto user : alloc->getUsers()) {
+    if (isa<linalg::CopyOp>(user) &&
+        getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(user)) {
+      continue;
+    }
+
    auto forallOp = dyn_cast<scf::ForallOp>(user);
    if (!forallOp ||
        !forallOpHasMappingType<gpu::GPUThreadMappingAttr,
```

compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToGlobalLoads.cpp

Lines changed: 225 additions & 0 deletions (new file)

```cpp
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cstdint>
#include <numeric>
#include <optional>
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-lower-to-global-loads"
#define LDBG(X) LLVM_DEBUG(llvm::dbgs() << X << "\n")

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPULOWERTOGLOBALLOADSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

static constexpr int kNumBitsPerCopy = 32;

static LogicalResult
distributeLinalgCopyToThreads(RewriterBase &rewriter, linalg::CopyOp copy,
                              ArrayRef<int64_t> workgroupSize,
                              int64_t subgroupSize) {
  LDBG("==== distributing op: ");
  LDBG(*copy);
  Location loc = copy.getLoc();

  // The linalg.copy we are dealing with represents a region we need to copy to
  // workgroup memory. Assume there are N threads in the workgroup, then there
  // are `num_subgroups = N / gpu.subgroup_size` subgroups in the workgroup.
  //
  // So we are slicing up the target memref into `num_subgroups` consecutive
  // slices, and threads in the same subgroup will copy their slice to workgroup
  // memory slice.

  // Get the copy size:
  auto copyMemRefType = cast<MemRefType>(copy.getOperand(1).getType());
  if (!memref::isStaticShapeAndContiguousRowMajor(copyMemRefType)) {
    return rewriter.notifyMatchFailure(copy,
                                       "Copy to non-static or non-contiguous, "
                                       "non-row major memref.");
  }
  int64_t rank = copyMemRefType.getRank();
  SmallVector<OpFoldResult> tileSize(rank - 1, rewriter.getIndexAttr(1));

  int64_t elementBitWidth = copyMemRefType.getElementTypeBitWidth();
  if (kNumBitsPerCopy % elementBitWidth != 0) {
    return rewriter.notifyMatchFailure(copy, "Copy size is not a multiple of "
                                             "element bit width.");
  }
  int64_t elementsPerCopy = kNumBitsPerCopy / elementBitWidth;

  // Divide the copy by subgroup, and load linearly.
  assert(workgroupSize[0] % subgroupSize == 0);

  int64_t numSubgroups = workgroupSize[0] / subgroupSize;
  int64_t totalCopySize = copyMemRefType.getNumElements();
  int64_t totalCopySizePerSubgroup = totalCopySize / numSubgroups;
  int64_t numCopiesPerThread =
      (totalCopySizePerSubgroup / elementsPerCopy) / subgroupSize;
  int64_t residualElements =
      totalCopySizePerSubgroup % (subgroupSize * elementsPerCopy);

  LDBG("-- elementsPerCopy: " << elementsPerCopy);
  LDBG("-- workgroupSize: " << workgroupSize[0]);
  LDBG("-- numSubgroups: " << numSubgroups);
  LDBG("-- totalCopySize: " << totalCopySize);
  LDBG("-- totalCopySizePerSubgroup: " << totalCopySizePerSubgroup);
  LDBG("-- numCopiesPerThread: " << numCopiesPerThread);
  LDBG("-- residualElements: " << residualElements);

  if (residualElements != 0) {
    return rewriter.notifyMatchFailure(
        copy, "Cannot proceed: cannot handle copying residual elements.");
  }

  Value subgroupId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc, nullptr);

  auto sourceType = cast<MemRefType>(copy.getOperand(0).getType());
  auto localType = cast<MemRefType>(copy.getOutputs().front().getType());

  auto getGlobalGatherIndex = [&](Value sgIdVal, Value lIdVal,
                                  Value indVar) -> Value {
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return rewriter.create<affine::AffineLinearizeIndexOp>(
        loc, ValueRange{sgIdVal, indVar, lIdVal, zero},
        ArrayRef<int64_t>{numSubgroups, numCopiesPerThread, subgroupSize,
                          elementsPerCopy},
        /*disjoint=*/true);
  };

  auto getSubgroupStoreBaseIndex = [&](Value sgIdVal, Value indVar) -> Value {
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    return getGlobalGatherIndex(sgIdVal, zero, indVar);
  };

  // Build a for loop skeleton:
  scf::ForOp forOp = rewriter.create<scf::ForOp>(
      loc, /*lb=*/rewriter.create<arith::ConstantIndexOp>(loc, 0),
      /*ub=*/rewriter.create<arith::ConstantIndexOp>(loc, numCopiesPerThread),
      /*steps=*/rewriter.create<arith::ConstantIndexOp>(loc, 1));

  auto delinearizeIndex = [&](Value index, ArrayRef<int64_t> shape) {
    return rewriter.create<affine::AffineDelinearizeIndexOp>(loc, index, shape)
        .getMultiIndex();
  };

  // For loop body:
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto inductionVar = forOp.getInductionVar();
    Value linearizedGatherIndices =
        getGlobalGatherIndex(subgroupId, laneId, inductionVar);
    ValueRange delinearizedGlobalIndices =
        delinearizeIndex(linearizedGatherIndices, sourceType.getShape());
    Value linearizedBaseIndices =
        getSubgroupStoreBaseIndex(subgroupId, inductionVar);
    ValueRange delinearizedLocalIndices =
        delinearizeIndex(linearizedBaseIndices, localType.getShape());
    rewriter.create<IREE::GPU::GlobalLoadDMAOp>(
        loc, copy.getOperand(0), delinearizedGlobalIndices,
        copy.getOutputs()[0], delinearizedLocalIndices);
  }

  // Sync at the end of the loop across threads.
  rewriter.replaceOpWithNewOp<gpu::BarrierOp>(copy);
  return success();
}

static LogicalResult isEligibleForGlobalDMA(linalg::CopyOp copy) {
  // Source must be global address and target must be workgroup address.
  auto sourceType = cast<MemRefType>(copy.getOperand(0).getType());
  auto targetType = cast<MemRefType>(copy.getOutputs().front().getType());

  if (!getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(copy)) {
    LDBG("-- Op: " << *copy);
    LDBG("-- does not have `use_global_load_dma` attribute, skipping.");
    return failure();
  }

  if (!hasGlobalMemoryAddressSpace(sourceType) ||
      !hasSharedMemoryAddressSpace(targetType)) {
    LDBG("-- Op: " << *copy);
    LDBG("-- incompatible source or target memory address space.");
    return failure();
  }

  // TODO: check that the copy's target memref is not a subview: a subview
  // cannot guarantee contiguity of dest memory region.
  return success();
}

struct LowerToDMAPattern : public OpRewritePattern<linalg::CopyOp> {
  LowerToDMAPattern(MLIRContext *context, ArrayRef<int64_t> workgroupSize,
                    int64_t subgroupSize)
      : OpRewritePattern<linalg::CopyOp>(context), workgroupSize(workgroupSize),
        subgroupSize(subgroupSize) {}

  LogicalResult matchAndRewrite(linalg::CopyOp copy,
                                PatternRewriter &rewriter) const override {
    if (failed(isEligibleForGlobalDMA(copy))) {
      return failure();
    }
    return distributeLinalgCopyToThreads(rewriter, copy, workgroupSize,
                                         subgroupSize);
  }

private:
  ArrayRef<int64_t> workgroupSize;
  int64_t subgroupSize;
};

namespace {
struct GPULowerToGlobalLoadsPass final
    : impl::GPULowerToGlobalLoadsPassBase<GPULowerToGlobalLoadsPass> {

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto funcOp = getOperation();

    std::optional<SmallVector<int64_t>> workgroupSize =
        mlir::iree_compiler::getWorkgroupSize(funcOp);
    if (!workgroupSize) {
      funcOp.emitOpError(
          "unimplemented: Distribution with dynamic workgroup size.");
      return signalPassFailure();
    }
    auto subgroupSize = mlir::iree_compiler::getSubgroupSize(funcOp);
    if (!subgroupSize) {
      funcOp.emitOpError(
          "unimplemented: Distribution with dynamic subgroup size.");
      return signalPassFailure();
    }

    RewritePatternSet patterns(context);
    patterns.add<LowerToDMAPattern>(context, *workgroupSize, *subgroupSize);
    (void)applyPatternsGreedily(funcOp, std::move(patterns));
  }
};
} // namespace
} // namespace mlir::iree_compiler
```
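For orientation, the loop-bound arithmetic in `distributeLinalgCopyToThreads` works out as follows for the `64x128xi8` copy with a 256-thread workgroup and subgroup size 64 used in the description: `elementsPerCopy = 32 / 8 = 4`, `numSubgroups = 256 / 64 = 4`, `totalCopySizePerSubgroup = 8192 / 4 = 2048`, `numCopiesPerThread = (2048 / 4) / 64 = 8`, `residualElements = 0`. In other words, each lane moves one 32-bit packet (four `i8` values) per iteration, so the loop runs 8 times rather than the 32 single-element iterations shown in the hand-written example in the description, and copies with non-zero `residualElements` are currently rejected by the pattern rather than handled with a tail copy. Below is a rough sketch of the resulting loop, written by hand in the style of the earlier example (SSA names are illustrative, and the index arithmetic is shown as `affine.apply` for readability even though the pass emits `affine.linearize_index` ops):

```mlir
// Hand-written sketch, not compiler output: 64x128xi8 copy,
// workgroup size 256, subgroup size 64 => 8 iterations of 4xi8 per lane.
scf.for %indvar = %c0 to %c8 step %c1 {
  // Per-lane gather index: lane_id * 4 + subgroup_id * 2048 + indvar * 256.
  %src_lin = affine.apply affine_map<()[s0, s1, s2] -> (s0 * 4 + s1 * 2048 + s2 * 256)>()[%lane_id, %subgroup_id, %indvar]
  %src:2 = affine.delinearize_index %src_lin into (64, 128) : index, index
  // Subgroup-uniform store base: subgroup_id * 2048 + indvar * 256.
  %dst_lin = affine.apply affine_map<()[s0, s1] -> (s0 * 2048 + s1 * 256)>()[%subgroup_id, %indvar]
  %dst:2 = affine.delinearize_index %dst_lin into (64, 128) : index, index
  iree_gpu.global_load_dma %source[%src#0, %src#1] -> %dest[%dst#0, %dst#1]
      : memref<64x128xi8, strided<[256, 1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
        -> memref<64x128xi8, #gpu.address_space<workgroup>>
}
gpu.barrier
```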

compiler/src/iree/compiler/Codegen/Common/GPU/GPUPromoteMatmulOperands.cpp

Lines changed: 24 additions & 8 deletions
```diff
@@ -29,14 +29,22 @@ namespace mlir::iree_compiler {
 
 namespace {
 /// Helper to insert copy with derived thread config.
-Value promoteValue(OpBuilder &builder, Location loc, Value v) {
+Value promoteValue(OpBuilder &builder, Location loc, Value v,
+                   bool useDirectLoad) {
   auto tensorType = cast<RankedTensorType>(v.getType());
   SmallVector<OpFoldResult> mixedSizes = tensor::getMixedSizes(builder, loc, v);
+
   Value empty = builder.create<tensor::EmptyOp>(loc, mixedSizes,
                                                 tensorType.getElementType());
   auto copy = builder.create<linalg::CopyOp>(loc, v, empty);
-  setLoweringConfig(
-      copy, IREE::GPU::DerivedThreadConfigAttr::get(builder.getContext()));
+
+  if (useDirectLoad) {
+    setLoweringConfig(
+        copy, IREE::GPU::UseGlobalLoadDMAAttr::get(builder.getContext()));
+  } else {
+    setLoweringConfig(
+        copy, IREE::GPU::DerivedThreadConfigAttr::get(builder.getContext()));
+  }
   return copy.getResult(0);
 }
 
@@ -95,7 +103,8 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) {
   }
 
   rewriter.setInsertionPointAfterValue(replacement);
-  replacement = promoteValue(rewriter, loc, replacement);
+  replacement =
+      promoteValue(rewriter, loc, replacement, /*useDirectLoad=*/false);
   valueToReplace.replaceUsesWithIf(replacement, [&](OpOperand &use) {
     return opsToReplaceUseIn.contains(use.getOwner());
   });
@@ -110,7 +119,7 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) {
 ///
 /// %empty = tensor.empty()
 /// %copy = linalg.copy %1 to %empty {
-///   lowering_config = #iree_gpu.derived_thread_config}
+///   lowering_config = #iree_gpu.{derived_thread_config|use_global_dma}}
 /// linalg.matmul ins(%0, %copy)
 ///
 /// If the producer is already a tilable op, the producer is just annotated with
@@ -122,7 +131,8 @@ void promoteResult(OpBuilder &builder, Operation *op, Value valToMakeShared) {
 /// %copy1 = linalg.copy %2 to %out_buffer
 /// %copy2 = linalg.copy %copy1 to %empty {
 ///   lowering_config = #iree_gpu.derived_thread_config}
-void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) {
+void promoteOperand(OpBuilder &builder, Operation *op, unsigned index,
+                    bool useDirectLoad) {
   auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
   if (!dpsOp)
     return;
@@ -162,12 +172,15 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) {
     return;
   }
 
-  auto replacement = promoteValue(builder, op->getLoc(), operand);
+  auto replacement =
+      promoteValue(builder, op->getLoc(), operand, useDirectLoad);
   op->setOperand(index, replacement);
 }
 
 struct GPUPromoteMatmulOperandsPass final
     : impl::GPUPromoteMatmulOperandsPassBase<GPUPromoteMatmulOperandsPass> {
+  using GPUPromoteMatmulOperandsPassBase::GPUPromoteMatmulOperandsPassBase;
+
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
 
@@ -187,7 +200,10 @@ struct GPUPromoteMatmulOperandsPass final
 
       builder.setInsertionPoint(op);
       for (auto operand : promotedOperands.value()) {
-        promoteOperand(builder, op, operand);
+        // TODO: move switch `useDirectLoad` to the promotion attr list.
+        // Here using a command line option should be only a temporary
+        // solution.
+        promoteOperand(builder, op, operand, useDirectLoad);
       }
     });
   }
```
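When `useDirectLoad` is set, the only difference from the existing promotion path is the lowering config attached to the inserted copy. A minimal sketch of the promoted-operand IR under that flag, following the doc comment above (shapes and value names are illustrative):

```mlir
%empty = tensor.empty() : tensor<64x128xi8>
%copy = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
          ins(%operand : tensor<64x128xi8>)
          outs(%empty : tensor<64x128xi8>) -> tensor<64x128xi8>
// %copy replaces the original matmul operand; the tagged copy is then left
// untouched until GPULowerToGlobalLoads lowers it after bufferization.
```

The `DerivedThreadConfigAttr` branch is unchanged, so pipelines that do not opt into DMA keep their current behavior.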

compiler/src/iree/compiler/Codegen/Common/GPU/GPUVerifyDistribution.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -76,6 +76,12 @@ struct GPUVerifyDistributionPass final
         continue;
       }
 
+      // Allow DMA copies.
+      if (isa<linalg::CopyOp>(op) &&
+          getLoweringConfig<IREE::GPU::UseGlobalLoadDMAAttr>(op)) {
+        continue;
+      }
+
       op->emitOpError(
           "write affecting operations on shared resources are restricted "
           "to lane or thread distributed contexts.");
```
