
Commit 73c0d4f

sebvincekuhar and Jakub Kuderski authored
[Codegen] Add XOR-based Swizzle Attribute (iree-org#21562)
This adds a new swizzle attribute: `xor_shuffle`. It swizzles an element at `(row, col)` into `(row, col_swizzled)` with `col_swizzled = ((row / perPhase) % maxPhase) ^ col`.

The definition is `#iree_codegen.xor_shuffle<row_width, access_width, row_stride, per_phase>`. By default, `row_stride == row_width` and `per_phase == 1`.

Example usage:
```
%alloc = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
%alloc_swizzle = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<128, 16>] : memref<32768xi8, #gpu.address_space<workgroup>>
```

To do reverse swizzling on GMEM loads using `global_to_lds`:
```
%val = iree_codegen.swizzle_hint %rawBuffer[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<?xi8, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>
```

---------

Co-authored-by: Jakub Kuderski <[email protected]>
1 parent f14e6b2 commit 73c0d4f
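As a sanity check on the `col_swizzled` formula above, here is a small scalar model of the offset mapping (illustrative C++, not IREE code; the function name and the flat-offset framing are mine). The asserted values match the lit tests added in this commit.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of #iree_codegen.xor_shuffle<row_width, access_width, row_stride,
// per_phase>. Assumes `offset` is aligned to `accessWidth`, which is the case
// for the vector accesses the pass rewrites.
int64_t xorShuffleOffset(int64_t offset, int64_t rowWidth, int64_t accessWidth,
                         int64_t rowStride, int64_t perPhase) {
  const int64_t maxPhase = rowWidth / accessWidth;        // groups per row
  const int64_t col = (offset % rowWidth) / accessWidth;  // group within the row
  const int64_t row = (offset / rowStride) / perPhase % maxPhase;
  const int64_t colSwizzled = row ^ col;
  return offset - (offset % rowWidth) + colSwizzled * accessWidth;
}

int main() {
  // xor_shuffle<128, 16>: offset 1952 -> 2000.
  assert(xorShuffleOffset(1952, 128, 16, /*rowStride=*/128, /*perPhase=*/1) == 2000);
  // xor_shuffle<128, 16, 128, 2>: offset 1056 -> 1120.
  assert(xorShuffleOffset(1056, 128, 16, 128, 2) == 1120);
  // xor_shuffle<128, 16, 8192>: offset 8448 -> 8464.
  assert(xorShuffleOffset(8448, 128, 16, 8192, 1) == 8464);
  return 0;
}
```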

File tree: 5 files changed (+270, -3 lines)


compiler/src/iree/compiler/Codegen/Common/ResolveSwizzleHints.cpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -187,6 +187,11 @@ static void resolveHintOp(RewriterBase &rewriter,
       continue;
     }
     if (auto gatherToLDSOp = dyn_cast<amdgpu::GatherToLDSOp>(user)) {
+      // Ignore swizzleHint on the Dst operand. Gather_to_lds writes elements
+      // of a subgroup contiguously in order of lane ID.
+      if (gatherToLDSOp.getDst() == hintOp) {
+        continue;
+      }
       int64_t accessBitWidth = cast<MemRefType>(hintOp.getOperand().getType())
                                    .getElementTypeBitWidth() *
                                accessWidth;
@@ -201,11 +206,11 @@
       if (accessBitWidth != transferBitWidth) {
         return;
       }
-
       gatherToLDSOps.push_back(gatherToLDSOp);
       continue;
     }
-    // Bail out if we can't rewrite all users.
+    // Emit an error if we can't rewrite all users.
+    hintOp.emitError() << "unsupported SwizzleHintOp user: " << user;
     return;
   }
 
```
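The early `continue` added here encodes the comment's reasoning: `amdgpu.gather_to_lds` writes each lane's transfer contiguously in lane-id order from the single destination offset, so a swizzle hint on the dst memref has no per-access address to rewrite. A tiny illustrative model of that destination layout (64 lanes and 16-byte transfers are assumed here, not taken from the patch):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative model of the comment above: gather_to_lds writes the
  // subgroup's elements contiguously in lane-id order starting at the dst
  // offset, so the LDS addresses are fixed regardless of any swizzle hint
  // on the dst memref. Lane count and transfer size are assumptions.
  const int64_t dstOffset = 0, transferBytes = 16, subgroupSize = 64;
  for (int64_t lane = 0; lane < subgroupSize; ++lane) {
    const int64_t ldsByte = dstOffset + lane * transferBytes;
    std::printf("lane %2lld writes LDS bytes [%lld, %lld)\n", (long long)lane,
                (long long)ldsByte, (long long)(ldsByte + transferBytes));
  }
  return 0;
}
```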

compiler/src/iree/compiler/Codegen/Common/test/resolve_swizzle_hints.mlir

Lines changed: 81 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-resolve-swizzle-hints, canonicalize, cse))" \
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-resolve-swizzle-hints, canonicalize, cse))" --verify-diagnostics \
 // RUN: --split-input-file --mlir-print-local-scope %s | FileCheck %s
 
 func.func @swizzle_load(%src: memref<?xf32>) -> vector<4xf32> {
@@ -70,6 +70,7 @@ func.func @swizzle_both(%src: memref<?xf32>) {
 // -----
 
 func.func @drop_swizzle_non_access_user(%src: memref<?xf32>) -> (memref<?xf32>, vector<4xf32>) {
+  // expected-error @+1 {{unsupported SwizzleHintOp user}}
   %0 = iree_codegen.swizzle_hint %src[#iree_codegen.rotate_rows<64, 4>] : memref<?xf32>
   %offset = arith.constant 68 : index
   %1 = vector.load %0[%offset] : memref<?xf32>, vector<4xf32>
@@ -243,3 +244,82 @@ func.func @swizzle_gather_to_lds_scalar(%src: memref<?xf32>, %offset: index) {
 // CHECK: %[[IELEM:.+]] = arith.muli %[[I]], %[[ROW_WIDTH]] : index
 // CHECK: %[[SWOFF:.+]] = arith.addi %[[ROTATEJ]], %[[IELEM]] : index
 // CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[DSTOFFSET]]]
+
+
+func.func @swizzle_load_xor(%src: memref<?xi8>) -> vector<16xi8> {
+  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16>] : memref<?xi8>
+
+  // ((int(1952/128) % 8) ^ (int(1952/16) % 8)) * 16 + int(1952/128) * 128 -> 2000
+  %offset = arith.constant 1952 : index
+  %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+  return %1 : vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor
+// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+// CHECK: %[[SWOFF:.+]] = arith.constant 2000 : index
+// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+// CHECK: return %[[VECTOR]]
+
+// -----
+
+func.func @swizzle_load_xor_phase2(%src: memref<?xi8>) -> vector<16xi8> {
+  %0 = iree_codegen.swizzle_hint %src[#iree_codegen.xor_shuffle<128, 16, 128, 2>] : memref<?xi8>
+
+  %offset = arith.constant 1056 : index
+  %1 = vector.load %0[%offset] : memref<?xi8>, vector<16xi8>
+  return %1 : vector<16xi8>
+}
+
+// CHECK-LABEL: func @swizzle_load_xor_phase2
+// CHECK-SAME: %[[SRC:[A-Za-z0-9]+]]: memref<?xi8>
+// CHECK: %[[SWOFF:.+]] = arith.constant 1120 : index
+// CHECK: %[[VECTOR:.+]] = vector.load %[[SRC]][%[[SWOFF]]]
+// CHECK: return %[[VECTOR]]
+
+// -----
+
+
+func.func @swizzle_raw_buffer_to_lds(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // 1 row, 3rd tile: 1*8192 + 2*128 = 8448 -> (0 XOR 1)*16 + 8448 = 8464
+  %offset = arith.constant 8448 : index
+  %lds = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+  %globalSwizzle = iree_codegen.swizzle_hint %global[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+  amdgpu.gather_to_lds %globalSwizzle[%offset], %lds[%c0]
+    : vector<16xi8>, memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>, memref<32768xi8, #gpu.address_space<workgroup>>
+
+  func.return
+}
+
+// CHECK-LABEL: func @swizzle_raw_buffer_to_lds
+// CHECK-SAME: %[[SRC:.+]]: memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK: %[[SWOFF:.+]] = arith.constant 8464 : index
+// CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index
+// CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]
+
+// -----
+
+
+func.func @swizzle_raw_buffer_to_lds_ignore_dst_op(%global : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // 1 row, 3rd tile: 1*8192 + 2*128 = 8448 -> (0 XOR 1)*16 + 8448 = 8464
+  %offset = arith.constant 8448 : index
+  %lds = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+  %ldsSwizzle = iree_codegen.swizzle_hint %lds[#iree_codegen.xor_shuffle<128, 16>] : memref<32768xi8, #gpu.address_space<workgroup>>
+  %globalSwizzle = iree_codegen.swizzle_hint %global[#iree_codegen.xor_shuffle<128, 16, 8192>] : memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+  amdgpu.gather_to_lds %globalSwizzle[%offset], %ldsSwizzle[%c0]
+    : vector<16xi8>, memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>, memref<32768xi8, #gpu.address_space<workgroup>>
+
+  func.return
+}
+
+// CHECK-LABEL: func @swizzle_raw_buffer_to_lds_ignore_dst_op
+// CHECK-SAME: %[[SRC:.+]]: memref<32768xi8, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK: %[[SWOFF:.+]] = arith.constant 8464 : index
+// CHECK: %[[LDSOFFSET:.+]] = arith.constant 0 : index
+// CHECK: %[[LDS:.+]] = memref.alloc() : memref<32768xi8, #gpu.address_space<workgroup>>
+// CHECK: amdgpu.gather_to_lds %[[SRC]][%[[SWOFF]]], %[[LDS]][%[[LDSOFFSET]]]
```
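The raw-buffer tests above rely on the XOR pattern being its own inverse: applying the same `xor_shuffle` mapping twice returns the original offset, which is the property behind the "reverse swizzling" on GMEM loads mentioned in the commit message. A small self-contained check of that involution (illustrative C++; the scalar model from the commit-message sketch is repeated here so the snippet stands alone):

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of xor_shuffle<rowWidth, accessWidth, rowStride, perPhase>.
static int64_t swizzle(int64_t offset, int64_t rowWidth, int64_t accessWidth,
                       int64_t rowStride, int64_t perPhase) {
  const int64_t maxPhase = rowWidth / accessWidth;
  const int64_t col = (offset % rowWidth) / accessWidth;
  const int64_t row = (offset / rowStride) / perPhase % maxPhase;
  return offset - (offset % rowWidth) + (row ^ col) * accessWidth;
}

int main() {
  // XOR is self-inverse, so applying the swizzle twice restores the offset
  // (the row term is unchanged because the swizzle never leaves the row).
  for (int64_t offset = 0; offset < 32768; offset += 16) {
    const int64_t once = swizzle(offset, 128, 16, /*rowStride=*/8192, /*perPhase=*/1);
    assert(swizzle(once, 128, 16, 8192, 1) == offset);
  }
  return 0;
}
```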

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp

Lines changed: 111 additions & 0 deletions
```diff
@@ -612,6 +612,117 @@ SymbolicUKernelProviderAttr::getMLIRUKernel(StringRef name, DictionaryAttr,
   return symbolTable.lookup(name);
 }
 
+//===---------------------------------------------------------------------===//
+// iree_codegen.xor_shuffle
+//===---------------------------------------------------------------------===//
+
+/// Extract column index for XOR swizzling.
+/// col = (id % rowAlignment) / accessWidth
+static Value extractCol(OpBuilder &builder, Location loc, OpFoldResult id,
+                        OpFoldResult rowAlignment, OpFoldResult accessWidth) {
+  AffineExpr d0, s0, s1;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr result = (d0 % s0).floorDiv(s1);
+  return getValueOrCreateConstantIndexOp(
+      builder, loc,
+      affine::makeComposedFoldedAffineApply(builder, loc, result,
+                                            {id, rowAlignment, accessWidth}));
+}
+
+/// Extract row index for XOR swizzling.
+/// row = ((id / rowStride) / perPhase) % rowAccessAlignment
+static Value extractRow(OpBuilder &builder, Location loc, OpFoldResult id,
+                        OpFoldResult rowStride, OpFoldResult perPhase,
+                        OpFoldResult rowAccessAlignment) {
+  AffineExpr d0, s0, s1, s2;
+  bindDims(builder.getContext(), d0);
+  bindSymbols(builder.getContext(), s0, s1, s2);
+  AffineExpr result = (d0.floorDiv(s0).floorDiv(s1)) % s2;
+  return getValueOrCreateConstantIndexOp(
+      builder, loc,
+      affine::makeComposedFoldedAffineApply(
+          builder, loc, result, {id, rowStride, perPhase, rowAccessAlignment}));
+}
+
+/// Swizzle the column of id.
+/// new_id = id - id % rowAlignment + colSwizzled * accessWidth
+static Value updateCol(OpBuilder &builder, Location loc, OpFoldResult id,
+                       Value colSwizzled, OpFoldResult rowAlignment,
+                       OpFoldResult accessWidth) {
+  AffineExpr d0, d1, s0, s1;
+  bindDims(builder.getContext(), d0, d1);
+  bindSymbols(builder.getContext(), s0, s1);
+  AffineExpr result = d0 - d0 % s0 + d1 * s1;
+  return getValueOrCreateConstantIndexOp(
+      builder, loc,
+      affine::makeComposedFoldedAffineApply(
+          builder, loc, result, {id, colSwizzled, rowAlignment, accessWidth}));
+}
+
+OpFoldResult XORShuffleAttr::swizzleOffset(OpBuilder &b, Location loc,
+                                           OpFoldResult offset,
+                                           Value src) const {
+  int64_t rotationInvariant =
+      getRowWidth() * (getRowWidth() / getAccessWidth());
+  int64_t rowStride =
+      getRowStride() != int64_t() ? getRowStride() : getRowWidth();
+  int64_t perPhase = getPerPhase() != int64_t() ? getPerPhase() : 1;
+
+  OpFoldResult id =
+      getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant);
+  Value idVal = getValueOrCreateConstantIndexOp(b, loc, id);
+
+  // Number of elements per row.
+  Value rowAlignmentVal = b.create<arith::ConstantIndexOp>(loc, getRowWidth());
+  // Number of elements per group.
+  Value accessWidthVal =
+      b.create<arith::ConstantIndexOp>(loc, getAccessWidth());
+  // Number of rows per phase.
+  Value perPhaseVal = b.create<arith::ConstantIndexOp>(loc, perPhase);
+  // Buffer stride.
+  Value rowStrideVal = b.create<arith::ConstantIndexOp>(loc, rowStride);
+  // Number of contiguous groups of elements per row (swizzled together).
+  Value rowAccessAlignmentVal =
+      b.create<arith::ConstantIndexOp>(loc, getRowWidth() / getAccessWidth());
+
+  Value colVal = extractCol(b, loc, idVal, rowAlignmentVal, accessWidthVal);
+  Value rowVal = extractRow(b, loc, idVal, rowStrideVal, perPhaseVal,
+                            rowAccessAlignmentVal);
+  auto colSwizzled = b.create<arith::XOrIOp>(loc, rowVal, colVal);
+
+  // Apply the swizzled column back to the initial id.
+  Value swizzledIdVal =
+      updateCol(b, loc, idVal, colSwizzled, rowAlignmentVal, accessWidthVal);
+  Value diff = b.create<arith::SubIOp>(loc, swizzledIdVal, idVal);
+  return b
+      .create<arith::AddIOp>(
+          loc, getValueOrCreateConstantIndexOp(b, loc, offset), diff)
+      .getResult();
+}
+
+int64_t XORShuffleAttr::getAccessElementCount() const {
+  return getAccessWidth();
+}
+
+LogicalResult
+XORShuffleAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                       int64_t rowWidth, int64_t accessWidth, int64_t rowStride,
+                       int64_t perPhase) {
+  if (rowWidth % accessWidth != 0) {
+    return emitError() << "expected access width to divide row width";
+  }
+  int64_t maxPhase = rowWidth / accessWidth;
+  if (perPhase > maxPhase) {
+    return emitError() << "per_phase must be smaller than max_phase";
+  }
+  if (rowStride % rowWidth != 0) {
+    return emitError() << "expected row width to divide row stride";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Initialize attributes
 //===----------------------------------------------------------------------===//
```
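`swizzleOffset` reduces the incoming offset with `getMinimumConstantOffsetValue(b, loc, offset, rotationInvariant)` and then adds the resulting displacement back onto the full offset. A quick standalone check of the periodicity that makes this reduction valid in the default `row_stride == row_width`, `per_phase == 1` case (illustrative C++, not IREE code):

```cpp
#include <cassert>
#include <cstdint>

// Displacement added to an offset by xor_shuffle<rowWidth, accessWidth> with the
// default row_stride == rowWidth and per_phase == 1 (scalar model, illustrative).
static int64_t displacement(int64_t offset, int64_t rowWidth, int64_t accessWidth) {
  const int64_t maxPhase = rowWidth / accessWidth;
  const int64_t col = (offset % rowWidth) / accessWidth;
  const int64_t row = (offset / rowWidth) % maxPhase;
  return (row ^ col) * accessWidth - (offset % rowWidth);
}

int main() {
  const int64_t rowWidth = 128, accessWidth = 16;
  // Same quantity as rotationInvariant in swizzleOffset:
  // rowWidth * (rowWidth / accessWidth).
  const int64_t rotationInvariant = rowWidth * (rowWidth / accessWidth);
  // The displacement only depends on `offset mod rotationInvariant`, so the
  // swizzle can be computed on the reduced id and added back to the offset.
  for (int64_t offset = 0; offset < 4 * rotationInvariant; offset += accessWidth) {
    assert(displacement(offset, rowWidth, accessWidth) ==
           displacement(offset + rotationInvariant, rowWidth, accessWidth));
  }
  return 0;
}
```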

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td

Lines changed: 69 additions & 0 deletions
````diff
@@ -549,4 +549,73 @@ def IREECodegen_UKernelDescriptorAttr :
   }];
 }
 
+
+//===---------------------------------------------------------------------===//
+// iree_codegen.xor_shuffle
+//===---------------------------------------------------------------------===//
+
+def IREECodegen_XORShuffleAttr :
+    AttrDef<IREECodegen_Dialect, "XORShuffle", [
+      DeclareAttrInterfaceMethods<IREECodegen_SwizzleAttrInterface, [
+        "swizzleOffset",
+        "getAccessElementCount"
+      ]>
+    ]> {
+  let mnemonic = "xor_shuffle";
+  let summary = "An attribute that describes an XOR-based swizzling pattern.";
+  let description = [{
+    Shuffles accesses of |access_width| within rows of size |row_width|.
+    Any access into a logical memref of shape
+    `memref<...xNx|access_width|x!eltype>`, where `N = row_width / access_width`,
+    at position `(i, j, 0)` is shuffled to `(i, ((i / per_phase) % N) XOR j, 0)`.
+    For example,
+
+    ```
+    row_width = 16, access_width = 4, per_phase = 1
+
+    0000 1111 2222 3333 /// 0 1 2 3
+    4444 5555 6666 7777 /// 0 1 2 3
+    8888 9999 AAAA BBBB /// 0 1 2 3
+    CCCC DDDD EEEE FFFF /// 0 1 2 3
+    ```
+
+    is swizzled to
+    ```
+    0000 1111 2222 3333 /// 0 1 2 3
+    5555 4444 7777 6666 /// 1 0 3 2
+    AAAA BBBB 8888 9999 /// 2 3 0 1
+    FFFF EEEE DDDD CCCC /// 3 2 1 0
+    ```
+
+    |per_phase| allows keeping the same shuffling across multiple consecutive
+    rows. For example,
+
+    ```
+    row_width = 16, access_width = 4, per_phase = 2
+
+    0000 1111 2222 3333 /// 0 1 2 3
+    4444 5555 6666 7777 /// 0 1 2 3
+    8888 9999 AAAA BBBB /// 0 1 2 3
+    CCCC DDDD EEEE FFFF /// 0 1 2 3
+    ```
+
+    is swizzled to
+    ```
+    0000 1111 2222 3333 /// 0 1 2 3
+    4444 5555 6666 7777 /// 0 1 2 3
+    9999 8888 BBBB AAAA /// 1 0 3 2
+    DDDD CCCC FFFF EEEE /// 1 0 3 2
+    ```
+
+    The pattern repeats for subsequent rows.
+  }];
+  let parameters = (ins
+    AttrParameter<"int64_t", "">:$row_width,
+    AttrParameter<"int64_t", "">:$access_width,
+    OptionalParameter<"int64_t", "Row stride. Defaults to row_width.">:$row_stride,
+    OptionalParameter<"int64_t", "Defaults to 1.">:$per_phase
+  );
+  let assemblyFormat = [{
+    `<` $row_width `,` $access_width (`,` $row_stride^)? (`,` $per_phase^)? `>`
+  }];
+  let genVerifyDecl = 1;
+}
+
 #endif // IREE_COMPILER_CODEGEN_DIALECT_IREECODEGENATTRS
````
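The group permutations shown in the two examples above (the `/// ...` annotations) follow directly from `((i / per_phase) % N) XOR c`. A few lines of illustrative C++ reproduce both tables:

```cpp
#include <cstdint>
#include <cstdio>

// Prints which logical group lands at each physical position for the
// row_width = 16, access_width = 4 examples in the description.
int main() {
  const int64_t rowWidth = 16, accessWidth = 4;
  const int64_t groups = rowWidth / accessWidth;  // N = 4
  for (int64_t perPhase : {1, 2}) {
    std::printf("per_phase = %lld\n", (long long)perPhase);
    for (int64_t i = 0; i < 4; ++i) {          // rows
      for (int64_t c = 0; c < groups; ++c) {   // physical group position
        // Physical position c holds logical group ((i / perPhase) % N) ^ c.
        std::printf("%lld ", (long long)(((i / perPhase) % groups) ^ c));
      }
      std::printf("\n");
    }
  }
  return 0;
}
```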

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -1208,6 +1208,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
 
   if (forROCDL) {
     funcPassManager.addPass(amdgpu::createAmdgpuMaskedloadToLoadPass);
+    // This pass needs to run before the ResolveSwizzleHints pass.
+    funcPassManager.addPass(amdgpu::createAmdgpuFoldMemRefOpsPass);
   }
 
   // This pass needs to run before SCF -> CF.
```
