Skip to content

Commit abd6119

Browse files
krzysz00Groverkss
authored andcommitted
[GPU] Use affine.linearize_index (and delinearize_index) where possible (iree-org#19122)
There have been issues with the composition of affine maps being too general and loosing important information, like the fact that affine_map<(s0 + s1 * 32 + ... - (s0 floorDiv 16) * 16)> realy should be affine_map<(s0 mod 16 + s1 * 32 + ...)>, and other issues with the ultimate IR that block low-level arithmetic optimizations. The affine.delinearize_index operation represents the div/mod chains needed to break a flat index into its component parts. A recently added affine.linearize_index operation is its inverse - combining multiple indices into a flat 1D value. Another advantage to linearize/delinearize is simpler upstream canonicalizations and lead to more streamlined generated code. This PR updates the vector distribution code and other GPU-related code that I could find to 1. Use affine.linearize_index to construct flat thread IDs 2. Use affine.delinearize_index in places where there was a floorDiv/mod chain. 3. Plumb the subgroup size through the transfer_read and transfer_write distribution patterns to enable better reasoning about when you do/don't need to take a mod of the lane ID
1 parent 6e886be commit abd6119

15 files changed

+224
-265
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/GPU/Transforms/Passes.h"
1818
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
1919
#include "mlir/Dialect/SCF/IR/SCF.h"
20+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2021

2122
namespace mlir::iree_compiler {
2223

@@ -87,9 +88,16 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
8788
assert(!(hasThreadMapping && hasWarpMapping));
8889
Value flatId = linearThreadId;
8990
if (hasWarpMapping) {
90-
OpFoldResult subgroupSizeVal = rewriter.getIndexAttr(subgroupSize);
91-
flatId = affine::makeComposedAffineApply(rewriter, loc, d0.floorDiv(d1),
92-
{flatId, subgroupSizeVal});
91+
if (flatWorkgroupSize % subgroupSize != 0) {
92+
return forallOp->emitOpError(
93+
"found warp mapped forall with non-multiple workgroup size");
94+
}
95+
flatId = rewriter
96+
.create<affine::AffineDelinearizeIndexOp>(
97+
loc, flatId,
98+
ArrayRef<int64_t>{flatWorkgroupSize / subgroupSize,
99+
subgroupSize})
100+
.getResult(0);
93101
}
94102

95103
SmallVector<Value> delinSizes;
@@ -190,23 +198,18 @@ void GPUDistributeForallPass::runOnOperation() {
190198
return signalPassFailure();
191199
}
192200

193-
AffineExpr x, y, z;
194-
bindSymbols(funcOp.getContext(), x, y, z);
195-
// Compute the linearized thread id.
196-
AffineExpr linearId =
197-
x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;
198-
199201
rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
200-
SmallVector<OpFoldResult> threadGrid = {
201-
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
202-
gpu::Dimension::x),
203-
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
204-
gpu::Dimension::y),
205-
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
206-
gpu::Dimension::z)};
207-
208-
Value linearThreadIdVal = affine::makeComposedAffineApply(
209-
rewriter, funcOp.getLoc(), linearId, threadGrid);
202+
SmallVector<Value> threadGrid = {rewriter.createOrFold<gpu::ThreadIdOp>(
203+
funcOp.getLoc(), gpu::Dimension::z),
204+
rewriter.createOrFold<gpu::ThreadIdOp>(
205+
funcOp.getLoc(), gpu::Dimension::y),
206+
rewriter.createOrFold<gpu::ThreadIdOp>(
207+
funcOp.getLoc(), gpu::Dimension::x)};
208+
SmallVector<int64_t> threadGridBasis = {workgroupSize[2], workgroupSize[1],
209+
workgroupSize[0]};
210+
211+
Value linearThreadIdVal = rewriter.create<affine::AffineLinearizeIndexOp>(
212+
funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true);
210213
for (auto forall : forallOps) {
211214
rewriter.setInsertionPoint(forall);
212215
if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal,

compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -189,30 +189,29 @@ SmallVector<linalg::ProcInfo> getIds(OpBuilder &b, Location loc,
189189
ArrayRef<Range> parallelLoopRanges,
190190
Value flatThreadId) {
191191
SmallVector<linalg::ProcInfo> infos;
192-
Value id = flatThreadId;
193-
AffineExpr d0 = b.getAffineDimExpr(0);
194-
for (Range r : llvm::reverse(parallelLoopRanges)) {
195-
linalg::ProcInfo info;
192+
SmallVector<int64_t> delinSizes;
193+
for (Range r : parallelLoopRanges) {
196194
auto offset = dyn_cast<Attribute>(r.offset);
197195
auto stride = dyn_cast<Attribute>(r.stride);
198196
auto size = dyn_cast<Attribute>(r.size);
199197
assert(offset && stride && size);
200198
int64_t numThreadsDim = (llvm::cast<IntegerAttr>(size).getInt() -
201199
llvm::cast<IntegerAttr>(offset).getInt()) /
202200
llvm::cast<IntegerAttr>(stride).getInt();
203-
Value dimId = id;
204-
if (infos.size() != parallelLoopRanges.size() - 1)
205-
dimId =
206-
affine::makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId});
201+
delinSizes.push_back(numThreadsDim);
202+
}
203+
ValueRange dims =
204+
b.create<affine::AffineDelinearizeIndexOp>(loc, flatThreadId, delinSizes)
205+
.getResults();
206+
207+
for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) {
208+
linalg::ProcInfo info;
207209
info.procId = dimId;
208210
info.nprocs = b.create<arith::ConstantIndexOp>(loc, numThreadsDim);
209211
info.distributionMethod =
210212
linalg::DistributionMethod::CyclicNumProcsEqNumIters;
211213
infos.push_back(info);
212-
id = affine::makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim),
213-
{id});
214214
}
215-
std::reverse(infos.begin(), infos.end());
216215
return infos;
217216
}
218217

@@ -288,19 +287,16 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp,
288287
ArrayRef<int64_t> workgroupSize) {
289288
OpBuilder b(funcOp.getFunctionBody());
290289
Type indexType = b.getIndexType();
291-
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
292-
AffineExpr d1 = getAffineDimExpr(1, b.getContext());
293-
AffineExpr d2 = getAffineDimExpr(2, b.getContext());
294290
Value threadX =
295291
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::x);
296292
Value threadY =
297293
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::y);
298294
Value threadZ =
299295
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::z);
300-
Value flatThreadId = affine::makeComposedAffineApply(
301-
b, funcOp.getLoc(),
302-
d0 + workgroupSize[0] * d1 + (workgroupSize[0] * workgroupSize[1]) * d2,
303-
{threadX, threadY, threadZ});
296+
Value flatThreadId = b.create<affine::AffineLinearizeIndexOp>(
297+
funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
298+
ArrayRef<int64_t>{workgroupSize[2], workgroupSize[1], workgroupSize[0]},
299+
/*disjoint=*/true);
304300
return flatThreadId;
305301
}
306302

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@ func.func @distribute_thread_forall(%out : memref<?xi32>)
1515
// CHECK-LABEL: func @distribute_thread_forall
1616
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
1717
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
18-
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
18+
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
1919
// CHECK: scf.for %[[I:.+]] = %c0 to %c1024 step %c128 {
20-
// CHECK: %[[LINID:.+]] = affine.apply
21-
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]])
22-
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
20+
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
2321
// CHECK: memref.store {{.*}}[%[[LINID]]]
2422

2523
// -----
@@ -38,11 +36,10 @@ func.func @distribute_warp_forall(%out : memref<?xi32>)
3836
// CHECK-LABEL: func @distribute_warp_forall
3937
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
4038
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
41-
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
39+
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
40+
// CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32)
4241
// CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 {
43-
// CHECK: %[[LINID:.+]] = affine.apply
44-
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 2 + s2 * 4 + s0 floordiv 32)>(%[[I]])
45-
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
42+
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0]
4643
// CHECK: memref.store {{.*}}[%[[LINID]]]
4744

4845
// -----
@@ -78,11 +75,7 @@ func.func @distribute_thread_forall_drop_for_loop(%out : memref<?xi32>)
7875
// CHECK-LABEL: func @distribute_thread_forall_drop_for_loop
7976
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
8077
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
81-
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
82-
// CHECK-NOT: scf.for
83-
// CHECK: %[[LINID:.+]] = affine.apply
84-
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>
85-
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
78+
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
8679
// CHECK: memref.store {{.*}}[%[[LINID]]]
8780

8881
// -----
@@ -99,13 +92,32 @@ func.func @distribute_thread_forall_single_thread(%out : memref<?xi32>)
9992
}
10093

10194
// CHECK-LABEL: func @distribute_thread_forall_single_thread
95+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
10296
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
10397
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
104-
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
105-
// CHECK: %[[LINID:.+]] = affine.apply
106-
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>
107-
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
108-
// CHECK: scf.for %[[I:.+]] = %[[LINID]] to %c1 step %c128 {
98+
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
99+
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 {
100+
// CHECK: memref.store {{.*}}[%[[I]]]
101+
102+
// -----
103+
104+
#translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 2, 1] subgroup_size = 32>
105+
106+
func.func @distribute_thread_forall_overhang(%out : memref<?xi32>)
107+
attributes {translation_info = #translation_info} {
108+
%c0 = arith.constant 0 : i32
109+
scf.forall (%arg0) in (513) {
110+
memref.store %c0, %out[%arg0] : memref<?xi32>
111+
} {mapping = [#gpu.thread<linear_dim_0>]}
112+
return
113+
}
114+
115+
// CHECK-LABEL: func @distribute_thread_forall_overhang
116+
// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index
117+
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
118+
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
119+
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
120+
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 {
109121
// CHECK: memref.store {{.*}}[%[[I]]]
110122

111123
// -----
@@ -124,11 +136,9 @@ func.func @distribute_thread_forall_multi_dim(%out : memref<?x?x?xi32>)
124136
// CHECK-LABEL: func @distribute_thread_forall_multi_dim
125137
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
126138
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
127-
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
139+
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
128140
// CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 {
129-
// CHECK: %[[LINID:.+]] = affine.apply
130-
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]])
131-
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
141+
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
132142
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
133143
// CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]
134144

@@ -147,10 +157,5 @@ func.func @distribute_thread_forall_small_workgroup(%out : memref<?xi32>)
147157
}
148158

149159
// CHECK-LABEL: func @distribute_thread_forall_small_workgroup
150-
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
151-
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
152-
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
153-
// CHECK: %[[LINID:.+]] = affine.apply
154-
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 7 + s2 * 7)>
155-
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
156-
// CHECK: memref.store {{.*}}[%[[LINID]]]
160+
// CHECK: %[[TX:.+]] = gpu.thread_id x
161+
// CHECK: memref.store {{.*}}[%[[TX]]]

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,37 +49,32 @@ module {
4949
}
5050
}
5151

52-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
53-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
54-
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)>
55-
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128)>
56-
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128 + 128)>
57-
// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 4 + s1 * 128 + s2 * 512)>
52+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)>
53+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)>
54+
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 128)>
5855
// CHECK-LABEL: @shared_mem_cpy(
5956

6057
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
6158
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
6259
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
6360
// CHECK-DAG: %[[TX:.*]] = gpu.thread_id x
6461
// CHECK-DAG: %[[TY:.*]] = gpu.thread_id y
65-
// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z
66-
67-
// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]]
68-
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]]
69-
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
70-
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
71-
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]]
62+
// CHECK: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32)
63+
// CHECK: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4)
64+
// CHECK: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1]
65+
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
66+
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
67+
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0]
7268
// CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
7369
// CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
7470

75-
// CHECK: %[[Y1:.*]] = affine.apply #[[$MAP3]]()[%[[TX]], %[[TY]], %[[TZ]]]
76-
// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
77-
// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
78-
// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP4]]()[%[[TX]], %[[TY]], %[[TZ]]]
71+
// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[TFLAT]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
72+
// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[TFLAT]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
73+
// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP2]]()[%[[TFLAT]]]
7974
// CHECK: %[[R3:.*]] = vector.transfer_read %{{.*}}[%[[Y2]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
8075
// CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
8176

82-
// CHECK: %[[X1:.*]] = affine.apply #[[$MAP5]]()[%[[TX]], %[[TY]], %[[TZ]]]
77+
// CHECK: %[[X1:.*]] = affine.apply #[[$MAP0]]()[%[[TFLAT]]]
8378
// CHECK: %[[R4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>
8479
// CHECK: vector.transfer_write %[[R4]], %{{.*}}[%[[C0]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
8580
// CHECK: %[[R5:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>

compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,19 @@ module attributes {transform.with_named_sequence} {
4646
transform.yield
4747
}
4848
}
49-
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
50-
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
51-
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)>
49+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)>
50+
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)>
5251
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)>
5352
// CHECK-LABEL: @shared_mem_cpy(
5453
// CHECK-DAG: %[[TX:.*]] = gpu.thread_id x
5554
// CHECK-DAG: %[[TY:.*]] = gpu.thread_id y
56-
// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z
5755

58-
// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]]
59-
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]]
60-
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
61-
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space<workgroup>>
62-
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]]
56+
// CHECK-DAG: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32)
57+
// CHECK-DAG: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4)
58+
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1]
59+
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
60+
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space<workgroup>>
61+
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0]
6362
// CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
6463
// CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space<workgroup>>
6564
// CHECK: linalg.generic

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,16 +1113,15 @@ transform_dialect::TestGpuVectorDistribution::applyToOne(
11131113
rewriter.setInsertionPointToStart(&target.getFunctionBody().front());
11141114
// This is a test op so we unsafely use thread_id x as the lane ID. In
11151115
// general this should linearize the thread IDs based on the workgroup size
1116-
// and divide by the subgroup size. i.e.
1116+
// and take the modulo by the subgroup size. i.e.
11171117
//
1118-
// lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) / subgroup_size;
1118+
// lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) % subgroup_size;
11191119
Value laneId =
11201120
rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x);
1121+
int64_t subgroupSize = getSubgroupSize();
11211122

11221123
populateGPUDistributionPatterns(patterns);
1123-
// For testing we use subgroup size = 64.
1124-
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
1125-
/*subgroupSize=*/64);
1124+
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize);
11261125
populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
11271126
if (failed(distributeVectorOps(target, patterns, options))) {
11281127
return emitDefaultDefiniteFailure(target);

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,8 @@ def TestGpuVectorDistribution :
631631
}];
632632

633633
let arguments = (ins TransformHandleTypeInterface:$target,
634-
DefaultValuedOptionalAttr<BoolAttr, "false">:$experimental);
634+
DefaultValuedOptionalAttr<BoolAttr, "false">:$experimental,
635+
DefaultValuedOptionalAttr<I64Attr, "64">:$subgroup_size);
635636
let results = (outs);
636637

637638
let assemblyFormat = [{ $target attr-dict `:` type($target)}];

0 commit comments

Comments
 (0)