
Commit 1e2abb8

Revert "Three reverts to undo transfer_write deduplication and return… (#22521)
… to previous state (#22392)" This reverts commit 4a716e2. The underlying issue was fixed by llvm/llvm-project#165764 . Thanks to @newling for figuring out this tricky issue. Fixes: #22397 ci-extra: test_torch
1 parent aa1773b commit 1e2abb8

14 files changed: +330 −35 lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp

Lines changed: 86 additions & 11 deletions
@@ -10,8 +10,10 @@
 #include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "iree/compiler/Utils/Indexing.h"
 #include "iree/compiler/Utils/Permutation.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -422,9 +424,70 @@ struct DistributeTransferWrite final
   using OpDistributionPattern::OpDistributionPattern;
 
   DistributeTransferWrite(MLIRContext *context, Value threadId,
-                          int64_t subgroupSize)
+                          int64_t subgroupSize, ArrayRef<int64_t> workgroupSize)
       : OpDistributionPattern(context), threadId(threadId),
-        subgroupSize(subgroupSize) {}
+        subgroupSize(subgroupSize) {
+
+    // The number of threads in the workgroup is the product of the dimensions
+    // of workgroupSize, unless workgroupSize is empty.
+    if (!workgroupSize.empty()) {
+      numThreadsInWorkgroup = llvm::product_of(workgroupSize);
+    }
+  }
+
+  /// Compute a boolean in SIMT semantics that is true for the first virtual
+  /// lane(thread) id (vtid) and virtual subgroup id (vsid) carrying broadcasted
+  /// data.
+  ///
+  /// We do this by computing a basis for vtid and vsid computation, and adding
+  /// a check for basis elements that are not used (i.e. they are duplicated)
+  /// to be zero.
+  FailureOr<Value> getNoOverlapCondition(OpBuilder &b, Location loc,
+                                         NestedLayoutAttr layout) const {
+    ArrayRef<int64_t> threadTile = layout.getThreadTile();
+    ArrayRef<int64_t> threadStrides = layout.getThreadStrides();
+    ArrayRef<int64_t> subgroupTile = layout.getSubgroupTile();
+    // Multiply the subgroup strides by subgroup_size to reflect thread id
+    // relative strides.
+    auto subgroupStrides =
+        llvm::map_to_vector(layout.getSubgroupStrides(),
+                            [&](int64_t x) { return x * subgroupSize; });
+    auto concatTiles =
+        llvm::to_vector(llvm::concat<const int64_t>(subgroupTile, threadTile));
+    auto concatStrides = llvm::to_vector(
+        llvm::concat<const int64_t>(subgroupStrides, threadStrides));
+    SmallVector<int64_t> basis;
+    SmallVector<size_t> dimToResult;
+    if (failed(basisFromSizesStrides(concatTiles, concatStrides, basis,
+                                     dimToResult))) {
+      return failure();
+    }
+    // Make the outer bound numThreadsInWorkgroup / prod(basis) to remove
+    // redundant checks.
+    if (numThreadsInWorkgroup.has_value()) {
+      int64_t outerBound =
+          numThreadsInWorkgroup.value() / llvm::product_of(basis);
+      basis.insert(basis.begin(), outerBound);
+    }
+    // Create a delinearize operation and check that all results not present in
+    // dimToResult are 0.
+    SmallVector<Value> delinearized;
+    b.createOrFold<affine::AffineDelinearizeIndexOp>(
+        delinearized, loc, threadId, basis,
+        /*hasOuterbound=*/numThreadsInWorkgroup.has_value());
+    // Get all results which are not in dimToResult and check they are 0.
+    Value condition = arith::ConstantOp::create(b, loc, b.getBoolAttr(true));
+    for (auto [idx, result] : llvm::enumerate(delinearized)) {
+      if (llvm::is_contained(dimToResult, idx)) {
+        continue;
+      }
+      Value isZero = b.createOrFold<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, result,
+          arith::ConstantIndexOp::create(b, loc, 0));
+      condition = b.createOrFold<arith::AndIOp>(loc, condition, isZero);
+    }
+    return condition;
+  }
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 DistributionSignature &signature,
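
To make the new no-overlap check concrete, here is a rough, self-contained C++ model of the same idea. It is an illustrative sketch only, not IREE's actual basisFromSizesStrides implementation; the helper names and the assumed workgroup size of 256 threads are invented for the example. The numbers mirror the @partially_distributed_write test added below: subgroup_size = 64, thread_tile = [2, 8] with thread_strides = [32, 1], and subgroup_tile = [4, 1] with subgroup_strides = [1, 1] (i.e. strides [64, 64] in thread-id units). The distributed strides {1, 32, 64} leave the stride range [8, 32) uncovered, so that digit of the delinearized thread id carries duplicated data and must be zero for a thread to be a designated writer.

// Rough scalar model of the no-overlap condition (illustrative only; not the
// exact algorithm used by the pattern above).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct Dim {
  int64_t size;
  int64_t stride;
  bool used; // false => this delinearized digit carries duplicated data
};

// Build a delinearization basis from the distributed (size, stride) pairs.
// Gaps between covered stride ranges become "unused" digits that must be zero
// for a thread to be the unique writer of its duplication group.
std::vector<Dim> buildBasis(std::vector<Dim> dims, int64_t numThreads) {
  dims.erase(std::remove_if(dims.begin(), dims.end(),
                            [](const Dim &d) {
                              return d.size == 1 || d.stride == 0;
                            }),
             dims.end());
  std::sort(dims.begin(), dims.end(),
            [](const Dim &a, const Dim &b) { return a.stride < b.stride; });
  std::vector<Dim> basis;
  int64_t covered = 1;
  for (const Dim &d : dims) {
    if (d.stride > covered) // broadcast gap between distributed dims
      basis.push_back({d.stride / covered, covered, /*used=*/false});
    basis.push_back({d.size, d.stride, /*used=*/true});
    covered = d.stride * d.size;
  }
  if (numThreads > covered) // threads beyond the layout, e.g. extra subgroups
    basis.push_back({numThreads / covered, covered, /*used=*/false});
  return basis;
}

// Analogue of the generated condition: every unused digit must be zero.
bool isDesignatedWriter(int64_t tid, const std::vector<Dim> &basis) {
  for (const Dim &d : basis)
    if (!d.used && (tid / d.stride) % d.size != 0)
      return false;
  return true;
}

int main() {
  constexpr int64_t kSubgroupSize = 64;
  constexpr int64_t kNumThreads = 256; // assumed workgroup size
  // subgroup_tile = [4, 1] at stride 64, thread_tile = [2, 8] at strides [32, 1].
  std::vector<Dim> dims = {{4, 1 * kSubgroupSize, false},
                           {1, 1 * kSubgroupSize, false},
                           {2, 32, false},
                           {8, 1, false}};
  std::vector<Dim> basis = buildBasis(dims, kNumThreads);
  int64_t writers = 0;
  for (int64_t tid = 0; tid < kNumThreads; ++tid)
    writers += isDesignatedWriter(tid, basis);
  // 4 (subgroups) * 2 * 8 (threads) = 64 distinct slices => 64 writers; the
  // remaining 192 threads hold duplicated data and are skipped.
  std::cout << "designated writers: " << writers << "\n";
  return 0;
}

The scf.if guard inserted into matchAndRewrite below plays the role of isDesignatedWriter in this model: only threads for which every unused digit is zero perform the transfer_write.
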
@@ -456,7 +519,6 @@ struct DistributeTransferWrite final
     SmallVector<int64_t> distShape = vectorLayout.getDistributedShape();
     SmallVector<int64_t> tileShape = getElementVectorTileShape(vectorLayout);
     int64_t rank = vectorLayout.getRank();
-
     SmallVector<Value> warpIndices, threadIndices;
     if (failed(populateWarpAndThreadIndices(rewriter, threadId, subgroupSize,
                                             vectorLayout, warpIndices,
@@ -465,6 +527,18 @@ struct DistributeTransferWrite final
           writeOp, "warp or thread tiles have overlapping strides");
     }
 
+    // If the distribution results in threads writing to the same address, guard
+    // with an scf.if to ensure only one thread writes per duplication group.
+    Location loc = writeOp.getLoc();
+    FailureOr<Value> doWrite =
+        getNoOverlapCondition(rewriter, loc, vectorLayout);
+    if (failed(doWrite)) {
+      return rewriter.notifyMatchFailure(
+          writeOp, "failed to compute no-overlap condition");
+    }
+    auto ifOp = scf::IfOp::create(rewriter, loc, doWrite.value());
+    rewriter.setInsertionPoint(ifOp.thenYield());
+
     Value distributedVector =
         getDistributed(rewriter, writeOp.getValueToStore(), vectorLayout);
 
@@ -485,7 +559,6 @@ struct DistributeTransferWrite final
       SmallVector<Value> slicedIndices = getTransferIndicesFromNestedLayout(
           rewriter, indices, offsets, vectorLayout, permMap, warpIndices,
          threadIndices);
-
      // Extract the "element vector" from the inner most dimensions. All outer
      // dimensions are either unrolled or distributed such that this is a
      // contiguous slice.
@@ -516,6 +589,7 @@ struct DistributeTransferWrite final
 
   Value threadId;
   int64_t subgroupSize;
+  std::optional<int64_t> numThreadsInWorkgroup = std::nullopt;
 };
 
 /// Pattern to distribute `vector.transfer_gather` ops with nested layouts.
@@ -2127,13 +2201,14 @@ struct DistributeConstantMask final
 
 } // namespace
 
-void populateGPUDistributeNestedLayoutAttrPatterns(RewritePatternSet &patterns,
-                                                   Value threadId,
-                                                   int64_t subgroupSize,
-                                                   int64_t maxBitsPerShuffle) {
-  patterns.add<DistributeTransferRead, DistributeTransferWrite,
-               DistributeTransferGather, DistributeMapScatter>(
-      patterns.getContext(), threadId, subgroupSize);
+void populateGPUDistributeNestedLayoutAttrPatterns(
+    RewritePatternSet &patterns, Value threadId, int64_t subgroupSize,
+    ArrayRef<int64_t> workgroupSize, int64_t maxBitsPerShuffle) {
+  patterns.add<DistributeTransferRead, DistributeTransferGather,
+               DistributeMapScatter>(patterns.getContext(), threadId,
+                                     subgroupSize);
+  patterns.add<DistributeTransferWrite>(patterns.getContext(), threadId,
+                                        subgroupSize, workgroupSize);
   patterns.add<DistributeBroadcast, DistributeTranspose>(patterns.getContext());
   patterns.add<DistributeMultiReduction>(patterns.getContext(), subgroupSize,
                                          maxBitsPerShuffle);

compiler/src/iree/compiler/Codegen/Common/GPU/GPUPatterns.h

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ void populateGPUDistributionPatterns(RewritePatternSet &patterns);
 
 void populateGPUDistributeNestedLayoutAttrPatterns(
     RewritePatternSet &patterns, Value threadId, int64_t subgroupSize,
-    int64_t maxBitsPerShuffle = 32);
+    ArrayRef<int64_t> workgroupSize, int64_t maxBitsPerShuffle = 32);
 
 // Adds patterns that distributes vector.contract ops with nested layout
 // annotations to amdgpu.mfma ops.

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir

Lines changed: 139 additions & 0 deletions
@@ -1409,3 +1409,142 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK-DAG: %[[DISTRIBUTED_IDX1:.+]] = arith.addi %[[IDX1]], %[[C8]]
 // CHECK: iree_linalg_ext.yield %[[DISTRIBUTED_IDX0]], %[[DISTRIBUTED_IDX1]]
 // CHECK: : vector<1x8xf16> into memref<64x64xf16>
+
+// -----
+
+// Check that only the first lane of the first subgroup writes when the threads
+// are completely undistributed (all threads write to same address).
+// CHECK-LABEL: @undistributed_write
+func.func @undistributed_write(%out: memref<f32, #amdgpu.address_space<fat_raw_buffer>>, %v: vector<f32>) {
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[TID:.*]] = gpu.thread_id x
+  // CHECK-DAG: %[[COND:.+]] = arith.cmpi eq, %[[TID]], %[[ZERO]] : index
+  // CHECK-NEXT: scf.if %[[COND]] {
+  // CHECK: vector.transfer_write
+  // CHECK-NEXT: }
+  vector.transfer_write %v, %out[] : vector<f32>, memref<f32, #amdgpu.address_space<fat_raw_buffer>>
+  return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroup_tile = [4, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [2, 8],
+  element_tile = [1, 2],
+  subgroup_strides = [1, 1],
+  thread_strides = [32, 1]
+>
+
+// subgroup_size = 64 (default for the transform test_gpu_vector_distribution)
+// A possible thread basis for this distribution would be:
+// thread_basis = [2, 4, 8] and the dimension with size "4" has data broadcasted
+// across all threads (note the thread strides). This test checks if we account
+// for such broadcasts when generating conditional writes.
+// CHECK-LABEL: @partially_distributed_write
+// CHECK-DAG: %[[TID:.+]] = gpu.thread_id x
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DELIN:.*]]:5 = affine.delinearize_index %[[TID:.+]] into (4, 2, 4, 8)
+// CHECK-DAG: %[[SUBGROUP_COND:.+]] = arith.cmpi eq, %[[DELIN]]#0, %[[C0]] : index
+// CHECK-DAG: %[[LANE_COND:.+]] = arith.cmpi eq, %[[DELIN]]#3, %[[C0]] : index
+// CHECK: %[[COND:.+]] = arith.andi %[[SUBGROUP_COND]], %[[LANE_COND]]
+// CHECK: scf.if %[[COND]] {
+// CHECK: vector.transfer_write
+func.func @partially_distributed_write(%out: memref<100x100xf32, #amdgpu.address_space<fat_raw_buffer>>, %v: vector<8x16xf32>) {
+  %w = iree_vector_ext.to_layout %v to layout(#layout_row_major) : vector<8x16xf32>
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %w, %out[%c0, %c0]
+    {in_bounds = [true, true]}
+    : vector<8x16xf32>, memref<100x100xf32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// In this example, threads with the same lane write to the same address. We check that only the first subgroup writes.
+// i.e. threads in [0, 64) will write, threads in [64, 256) will not write.
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 64],
+  element_tile = [64, 1],
+  subgroup_strides = [1, 1],
+  thread_strides = [1, 1]
+>
+
+// CHECK-LABEL: @lanes_fully_distributed
+// CHECK-DAG: %[[TID:.+]] = gpu.thread_id
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DELIN:.*]]:2 = affine.delinearize_index %[[TID:.+]] into (4, 64)
+// CHECK: %[[COND:.+]] = arith.cmpi eq, %[[DELIN]]#0, %[[C0]] : index
+// CHECK: scf.if %[[COND]] {
+// CHECK: vector.transfer_write
+func.func @lanes_fully_distributed(%out: memref<100x100xf32, #amdgpu.address_space<fat_raw_buffer>>, %v: vector<64x64xf32>) {
+  %w = iree_vector_ext.to_layout %v to layout(#layout_row_major) : vector<64x64xf32>
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %w, %out[%c0, %c0]
+    {in_bounds = [true, true]}
+    : vector<64x64xf32>, memref<100x100xf32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func {workgroup_size = array<i64: 256, 1, 1>} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// This example is similar to the above, but now the workgroup only contains 64 threads, so no condition is needed. Confirm there is no condition.
+#layout_row_major = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 64],
+  element_tile = [64, 1],
+  subgroup_strides = [1, 1],
+  thread_strides = [1, 1]
+>
+
+// CHECK-LABEL: @threads_fully_distributed
+// CHECK-NOT: scf.if
+// CHECK: transfer_write
+// CHECK: return
+func.func @threads_fully_distributed(%out: memref<100x100xf32, #amdgpu.address_space<fat_raw_buffer>>, %v: vector<64x64xf32>) {
+  %w = iree_vector_ext.to_layout %v to layout(#layout_row_major) : vector<64x64xf32>
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %w, %out[%c0, %c0]
+    {in_bounds = [true, true]}
+    : vector<64x64xf32>, memref<100x100xf32, #amdgpu.address_space<fat_raw_buffer>>
+  func.return
+}
+
+builtin.module attributes { transform.with_named_sequence } {
+  transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
+    %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
+    transform.iree.test_gpu_vector_distribution %top_level_func {workgroup_size = array<i64: 64, 1, 1>} : !transform.any_op
+    transform.yield
+  }
+}

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_mask.mlir

Lines changed: 7 additions & 7 deletions
@@ -2,9 +2,9 @@
 
 #nested = #iree_vector_ext.nested_layout<
   subgroup_tile = [2, 1],
-  batch_tile = [2, 1],
+  batch_tile = [8, 1],
   outer_tile = [2, 1],
-  thread_tile = [16, 16],
+  thread_tile = [4, 16],
   element_tile = [2, 8],
 
   subgroup_strides = [1, 0],
@@ -34,13 +34,13 @@ builtin.module attributes { transform.with_named_sequence } {
 // CHECK-LABEL: func @masked_read_write
 // CHECK: %[[DIM:.+]] = memref.dim %arg0, %c0 : memref<?x128xf16>
 // CHECK: %[[VSID:.+]]:3 = affine.delinearize_index %thread_id_x into (2, 64) : index, index, index
-// CHECK: %[[VTID:.+]]:3 = affine.delinearize_index %thread_id_x into (16, 16) : index, index, index
+// CHECK: %[[VTID:.+]]:3 = affine.delinearize_index %thread_id_x into (4, 16) : index, index, index
 // CHECK: %[[LASTIDX:.+]] = arith.subi %[[DIM]], %c1 : index
-// CHECK: %[[PACKED_LASTIDX:.+]]:4 = affine.delinearize_index %[[LASTIDX]] into (2, 4, 16, 2) : index, index, index, index
+// CHECK: %[[PACKED_LASTIDX:.+]]:4 = affine.delinearize_index %[[LASTIDX]] into (2, 16, 4, 2) : index, index, index, index
 
-// CHECK: %[[ETILE_VALID:.+]] = affine.linearize_index [%[[PACKED_LASTIDX]]#1, %c1] by (4, 2) : index
+// CHECK: %[[ETILE_VALID:.+]] = affine.linearize_index [%[[PACKED_LASTIDX]]#1, %c1] by (16, 2) : index
 // CHECK: %[[ETILE_VALID_BOUND:.+]] = arith.addi %[[ETILE_VALID]], %c1 : index
-// CHECK: %[[DISTR_LASTIDX:.+]] = affine.linearize_index [%[[PACKED_LASTIDX]]#1, %[[PACKED_LASTIDX]]#3] by (4, 2) : index
+// CHECK: %[[DISTR_LASTIDX:.+]] = affine.linearize_index [%[[PACKED_LASTIDX]]#1, %[[PACKED_LASTIDX]]#3] by (16, 2) : index
 // CHECK: %[[DISTR_BOUND:.+]] = arith.addi %[[DISTR_LASTIDX]], %c1 : index
 
 // CHECK: %[[EQ_BOUND_TID:.+]] = arith.cmpi eq, %[[VTID]]#1, %[[PACKED_LASTIDX]]#2 : index
@@ -50,7 +50,7 @@ builtin.module attributes { transform.with_named_sequence } {
 
 // CHECK: %[[SELTREE0:.+]] = arith.select %[[LT_BOUND_TID]], %[[ETILE_VALID_BOUND]], %c0 : index
 // CHECK: %[[SELTREE1:.+]] = arith.select %[[EQ_BOUND_TID]], %[[DISTR_BOUND]], %[[SELTREE0]] : index
-// CHECK: %[[SELTREE2:.+]] = arith.select %[[LT_BOUND_SID]], %c8, %c0 : index
+// CHECK: %[[SELTREE2:.+]] = arith.select %[[LT_BOUND_SID]], %c32, %c0 : index
 // CHECK: %[[SELTREE3:.+]] = arith.select %[[EQ_BOUND_SID]], %[[SELTREE1]], %[[SELTREE2]] : index
 // CHECK: %[[MASK:.+]] = vector.create_mask %[[SELTREE3]], %c8 : vector<2x8xi1>
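
A quick arithmetic check of the updated FileCheck constants, assuming the usual nested-layout convention that the per-thread slice along a dimension spans batch_tile * outer_tile * element_tile elements:
  old dim-0 slice per thread: 2 * 2 * 2 = 8   -> old %c8 and linearize basis (4, 2)
  new dim-0 slice per thread: 8 * 2 * 2 = 32  -> new %c32 and linearize basis (16, 2)
  total dim-0 size unchanged: 2*2*2*16*2 = 2*8*2*4*2 = 256
The VTID delinearization basis simply follows the thread_tile change from (16, 16) to (4, 16).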