Commit 07478c2

[AMD] Support shared encoding swizzle for BufferLoadToLocal (triton-lang#6329)
Adds lowering support for actual swizzled shared encodings. Because our direct-to-LDS loads must write coalesced into shared memory, we have to apply the swizzling to the global pointers instead. This is done by exchanging the global addresses between the lanes of a warp via `permute`, following the swizzle pattern. In the future we might want to apply the swizzling by changing the source layout so that each lane directly computes the right address, but currently this does not work because ops like `ExpandDim` do not work if the distributed layout moves in 2 dimensions in a basis. Support for non-global loads to LDS will follow in a separate PR after this one has landed.
1 parent 52e7a4b commit 07478c2
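Before the diffs, here is a minimal standalone sketch (not code from this PR) of the mechanism the commit message describes: with the phase-XOR swizzle pattern assumed here (phase = (row / perPhase) % maxPhase, swizzled chunk = chunk ^ phase), the element that must land in a lane's coalesced shared-memory slot is owned by another lane of the same warp, so each lane fetches that lane's global pointer via a lane permute before issuing the direct-to-LDS load. The `swizzleChunk` helper and the one-chunk-per-lane layout are illustrative assumptions.

```cpp
// Sketch: which lane a given lane must take its global pointer from, assuming
// an XOR-phase swizzle and one vec-sized chunk per lane along the swizzled dim.
#include <cstdio>

// phase = (row / perPhase) % maxPhase; the chunk written to slot `chunk` is
// the one whose swizzled position equals `chunk`, i.e. chunk ^ phase.
unsigned swizzleChunk(unsigned chunk, unsigned row, unsigned perPhase,
                      unsigned maxPhase) {
  unsigned phase = (row / perPhase) % maxPhase;
  return chunk ^ phase;
}

int main() {
  // Example values matching the first new lit test: vec = 1, perPhase = 1,
  // maxPhase = 4, 64 lanes along the contiguous dimension.
  const unsigned perPhase = 1, maxPhase = 4;
  for (unsigned row = 0; row < 4; ++row) {
    for (unsigned lane = 0; lane < 8; ++lane) { // print a few lanes only
      unsigned srcLane = swizzleChunk(lane, row, perPhase, maxPhase);
      // srcLane stays inside [0, 64): the swizzle never leaves the warp.
      printf("row %u: lane %2u takes the global ptr of lane %2u\n", row, lane,
             srcLane);
    }
  }
}
```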

4 files changed, +283 -35 lines

test/Conversion/amd/buffer_load_to_local_to_llvm.mlir

Lines changed: 120 additions & 0 deletions
@@ -162,3 +162,123 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [8, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_swizzled_simple
+  tt.func public @buffer_load_swizzled_simple(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                              %arg1: !tt.ptr<f16>,
+                                              %arg2: tensor<16x64xi32, #blocked>,
+                                              %arg3: !ttg.memdesc<16x64xf16, #shared, #smem, mutable>) {
+    // Each thread needs to load 2 elements and we load 1 (sizePerThread) per buffer load instruction
+    // COMMON: rocdl.make.buffer.rsrc
+    // COMMON-NOT: rocdl.make.buffer.rsrc
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
+    %65 = amdgpu.buffer_load_to_local %arg1[%arg2] into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<16x64xi32, #blocked>] -> <16x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_swizzled_mask_other
+  tt.func public @buffer_load_to_local_swizzled_mask_other(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
+                                                           %arg1: !tt.ptr<f16>,
+                                                           %arg2: tensor<32x32xi32, #blocked>,
+                                                           %arg3: !ttg.memdesc<32x32xf16, #shared, #smem, mutable>,
+                                                           %arg4: i32) {
+    // We need the splat to allow the AxisAnalysis to work during lowering
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x32xf16, #blocked>
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+    %c31_i32 = arith.constant 31 : i32
+    %1 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #blocked>
+    %29 = arith.addi %arg4, %c31_i32 : i32
+    %30 = arith.divsi %29, %c32_i32 : i32
+    %31 = arith.cmpi sgt, %30, %c0_i32 : i32
+
+    %51 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %52 = tt.expand_dims %51 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
+    %65 = tt.splat %arg4 : i32 -> tensor<32x1xi32, #blocked>
+    %66 = arith.cmpi slt, %52, %65 : tensor<32x1xi32, #blocked>
+    %67 = tt.broadcast %66 : tensor<32x1xi1, #blocked> -> tensor<32x32xi1, #blocked>
+
+    %70 = tt.splat %31 : i1 -> tensor<32x32xi1, #blocked>
+    %71 = arith.andi %70, %67 : tensor<32x32xi1, #blocked>
+
+    // Each thread needs to load 4 elements and we load 1 (sizePerThread) per buffer load instruction
+    // Note that mask/other alignment is 1 so we need 4 conditionals
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON: rocdl.ds_bpermute
+    // COMMON: rocdl.ballot
+    // COMMON: rocdl.raw.ptr.buffer.load.lds
+    // COMMON: _predicated_store
+
+    // COMMON-NOT: rocdl.ds_bpermute
+    // COMMON-NOT: rocdl.ballot
+    // COMMON-NOT: rocdl.raw.ptr.buffer.load.lds
+    // COMMON-NOT: _predicated_store
+
+    amdgpu.buffer_load_to_local %arg1[%arg2] mask=%67 other=%cst_0 into %arg3 {OpIdx = #amdgpu.OpIdx<1>} : <f16>[tensor<32x32xi32, #blocked>] tensor<32x32xf16, #blocked> -> <32x32xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 32], order = [0, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 16, order = [0, 1]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.shared = 0 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // COMMON-LABEL: buffer_load_to_local_swizzled_vectorized_8xf16
+  tt.func public @buffer_load_to_local_swizzled_vectorized_8xf16(%arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !ttg.memdesc<64x64xf16, #shared, #smem, mutable>) {
+    %cst = arith.constant dense<64> : tensor<1x64xi32, #blocked>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %2 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.expand_dims %1 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+    %5 = arith.muli %4, %cst : tensor<1x64xi32, #blocked>
+    %6 = tt.broadcast %5 : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %7 = arith.addi %3, %6 : tensor<64x64xi32, #blocked>
+
+    // Each thread needs to load 8 elements and we load 8 (sizePerThread) per buffer load instruction
+    // GFX950: rocdl.make.buffer.rsrc
+    // GFX950: rocdl.ds_bpermute
+    // GFX950: rocdl.raw.ptr.buffer.load.lds
+    // GFX950-NOT: rocdl.raw.ptr.buffer.load.lds
+
+    // GFX942 does not support vectorization > 4bytes so we cannot lower it
+    // GFX942-NOT: rocdl.raw.ptr.buffer.load.lds
+    // GFX942: amdgpu.buffer_load_to_local
+    %8 = amdgpu.buffer_load_to_local %arg1[%7] into %arg2 : <f16>[tensor<64x64xi32, #blocked>] -> <64x64xf16, #shared, #smem, mutable>
+    tt.return
+  }
+}
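The CHECK counts in the first test above follow from simple arithmetic; the snippet below is a hypothetical helper (not part of the test suite) that just spells it out for the @buffer_load_swizzled_simple encoding: a 16x64 f16 tile split across 8 warps of 64 lanes with sizePerThread = [1, 1] and vec = 1 leaves two elements per lane, hence two direct-to-LDS loads per lane, each preceded by one ds_bpermute for the swizzled pointer.

```cpp
// Rough arithmetic behind the expected instruction counts in the first test.
#include <cstdio>

int main() {
  const unsigned tileElems = 16 * 64;             // 16x64 f16 tile
  const unsigned warps = 8, lanes = 64, vecElems = 1;
  unsigned loadsPerLane = tileElems / (warps * lanes * vecElems);
  printf("rocdl.raw.ptr.buffer.load.lds per lane: %u\n", loadsPerLane); // 2
  printf("rocdl.ds_bpermute per lane:             %u\n", loadsPerLane); // one per load
}
```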

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 115 additions & 22 deletions
@@ -420,31 +420,90 @@ struct BufferLoadToLocalOpConversion
     if (llOther)
       otherElems = unpackLLElements(loc, llOther, rewriter);
 
-    // buffer_load into LDS does not support per lane offsets.
-    // We need to ensure that we write coalesced into shared memory.
     auto dstTy = op.getDest().getType();
-    if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(rewriter, ptrType, dstTy,
-                                                     vec)) {
+    auto sharedEnc = cast<SwizzledSharedEncodingAttr>(dstTy.getEncoding());
+
+    // buffer_load into LDS does not support per lane shared offsets. We need to
+    // ensure that we write coalesced into shared memory.
+    //
+    // For *non* swizzled shared encodings we check if they result in
+    // coalesced writes and can then lower them directly to the intrinsics.
+    //
+    // For swizzled shared encodings we need to transfer the swizzling to the
+    // source pointers. For now this is done by swizzling the pointers between
+    // the lane of a warp via permute. This only works if the swizzle pattern
+    // does not exchange elements between warps which holds for all our swizzle
+    // patterns. There is still a check performed to not silently produce wrong
+    // results if we invalidate the condition in the future
+
+    bool hasSwizzling = sharedEnc.getMaxPhase() != 1;
+
+    // Compute the blocked -> shared linear layout to check preconditions
+    auto shape = ptrType.getShape();
+    LinearLayout srcLayout =
+        triton::gpu::toLinearLayout(shape, ptrType.getEncoding());
+    LinearLayout sharedLayout =
+        triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
+    LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
+
+    unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter);
+    if (!hasSwizzling && !LLVM::AMD::canCoalesceWriteIntoSharedMemory(
+                             rewriter, srcToSharedLayout, threadsPerWarp)) {
+      return rewriter.notifyMatchFailure(
+          op, "does not write coalesced into LDS and is not swizzled");
+    }
+
+    if (hasSwizzling && !LLVM::AMD::doesSwizzleInsideWarp(
+                            rewriter, srcToSharedLayout, threadsPerWarp)) {
       return rewriter.notifyMatchFailure(op,
-                                         "does not write coalesced into LDS");
+                                         "does swizzle across warp boundaries");
     }
 
     auto resElemTy = getTypeConverter()->convertType(dstTy.getElementType());
     auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct(
         loc, llDst, resElemTy, rewriter);
 
-    // First we determine the vector size per load and collect the
-    // shared addresses. This will only emit the address calculation and not the
-    // actual loads
+    auto emitSharedAddresses = [&](RankedTensorType srcTy, MemDescType dstTy,
+                                   SmallVector<Value> &shmemAddrs,
+                                   VectorType &vecTy) {
+      bool ok = emitTransferBetweenRegistersAndShared(
+          ptrType, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
+          [&](VectorType vecTy_, Value shmemAddr) {
+            vecTy = vecTy_;
+            shmemAddrs.push_back(shmemAddr);
+          });
+      assert(ok);
+    };
+
+    // Determine the vector size per load and collect the shared addresses. This
+    // will only emit the address calculation and not the actual loads.
+    // For swizzled loads we get the non swizzled/coalesced shared addresses
+    // from a temporary non swizzled layout. Those addresses will be used as the
+    // store addresses. Additionally, we compute the swizzled shared memory
+    // addresses which will be used to compute which lane holds the global ptr
+    // to the coalesced address
     VectorType vecTy;
-    SmallVector<Value> shmemAddrs;
-    bool ok = emitTransferBetweenRegistersAndShared(
-        ptrType, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo,
-        [&](VectorType vecTy_, Value shmemAddr) {
-          vecTy = vecTy_;
-          shmemAddrs.push_back(shmemAddr);
-        });
-    assert(ok);
+    SmallVector<Value> coalescedShmemAddr;
+    SmallVector<Value> swizzledShmemAddr;
+
+    if (!hasSwizzling) {
+      emitSharedAddresses(ptrType, dstTy, coalescedShmemAddr, vecTy);
+    } else {
+      emitSharedAddresses(ptrType, dstTy, swizzledShmemAddr, vecTy);
+      // Create non swizzled/coalesced encoding
+      auto dstEnc = cast<SwizzledSharedEncodingAttr>(dstTy.getEncoding());
+      auto flatSharedEnc = SwizzledSharedEncodingAttr::get(
+          getContext(), dstEnc.getVec(), 1, 1, dstEnc.getOrder(),
+          dstEnc.getCTALayout());
+      auto flatDstTy =
+          MemDescType::get(dstTy.getShape(), dstTy.getElementType(),
+                           flatSharedEnc, dstTy.getMemorySpace());
+      VectorType coalescedVecTy;
+      emitSharedAddresses(ptrType, flatDstTy, coalescedShmemAddr,
+                          coalescedVecTy);
+      assert(coalescedVecTy == vecTy);
+    }
+    assert(vecTy.getNumElements() == vec);
 
     int vecBits = vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
     if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) {
@@ -462,17 +521,43 @@ struct BufferLoadToLocalOpConversion
     // based on the collected shared addresses and vector size
     Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride);
 
-    for (int i = 0; i < shmemAddrs.size(); i++) {
+    for (int i = 0; i < coalescedShmemAddr.size(); i++) {
       auto srcIdx = i * vec;
       auto offsetIn = offsetElems[srcIdx];
-
       Value pred = mask ? maskElems[srcIdx] : b.true_val();
+
+      if (hasSwizzling) {
+        // Compute the laneOffset based on the difference in elements between
+        // the two shmem addresses. laneOffset will be negative for half the
+        // lanes because a smaller laneId might hold our global_ptr.
+        auto coalescedAddr = b.ptrtoint(i64_ty, coalescedShmemAddr[i]);
+        auto swizzledAddr = b.ptrtoint(i64_ty, swizzledShmemAddr[i]);
+        auto diff = b.trunc(i32_ty, b.sub(swizzledAddr, coalescedAddr));
+        Value laneOffset = b.sdiv(diff, vecBytesVal);
+        // selectLane will always stay inside the warp [0,
+        // threadsPerWarp) because we only swizzle inside a warp
+        Value selectLane = b.add(getLaneId(rewriter, loc), laneOffset);
+
+        offsetIn = targetInfo.shuffleIdx(rewriter, loc, offsetIn, selectLane);
+
+        if (mask) {
+          // To swizzle the mask we can use ballot and then select the bit based
+          // on the lane id
+          auto warpMask =
+              targetInfo.ballot(rewriter, loc, rewriter.getI64Type(), pred);
+          // Extract the selectLane bit
+          auto bitMask =
+              b.lshr(warpMask, b.zext(rewriter.getI64Type(), selectLane));
+          pred = b.trunc(i1_ty, bitMask);
+        }
+      }
+
       bufferEmitter.emitLoadToLds(vecTy, vecBytesVal, rsrcDesc, offsetIn,
-                                  shmemAddrs[i], pred, op.getCache());
+                                  coalescedShmemAddr[i], pred, op.getCache());
       if (!otherElems.empty()) {
         Value storeVal = packElementRangeIntoVector(
            rewriter, this->getTypeConverter(), loc, vecTy, otherElems, srcIdx);
-        llStore(rewriter, loc, shmemAddrs[i], storeVal,
+        llStore(rewriter, loc, coalescedShmemAddr[i], storeVal,
                 b.icmp_ne(maskElems[srcIdx], b.true_val()), op.getCache());
       }
     }
@@ -534,11 +619,19 @@ struct AsyncCopyGlobalToLocalOpConversion
     auto maskElements = getMaskElemsAndUpdateVeclen(
         rewriter, loc, adaptor.getMask(), op.getMask(), maxVec);
 
+    auto shape = srcTy.getShape();
+    LinearLayout srcLayout =
+        triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
+    LinearLayout sharedLayout =
+        triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
+    LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
+
     // global.load.lds does not support per lane offsets.
     // We need to ensure that we write coalesced into shared memory. This means
     // that the kLane dim needs to be contigeous based on the vector size.
-    if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(rewriter, srcTy, dstTy,
-                                                     maxVec)) {
+    unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter);
+    if (!LLVM::AMD::canCoalesceWriteIntoSharedMemory(
+            rewriter, srcToSharedLayout, threadsPerWarp)) {
       return rewriter.notifyMatchFailure(op,
                                          "does not write coalesced into LDS");
     }
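The core of the new lowering is easiest to see as a warp-level simulation. The sketch below is an illustrative C++ model, not the conversion code itself: arrays indexed by lane stand in for per-lane SSA values, the XOR-by-3 swizzle is a made-up stand-in for a real swizzled layout, and `shuffleIdx`/`ballot` are modelled with a plain array read and a bit mask. It shows the two tricks added above: deriving the source lane from the byte difference between the swizzled and coalesced shared addresses, and swizzling the per-lane predicate through a ballot mask.

```cpp
// Warp-level sketch of the swizzled BufferLoadToLocal lowering (assumptions:
// 64-lane warp, one f16 per lane, XOR-by-3 swizzle as a placeholder).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int threadsPerWarp = 64;
  const int vecBytes = 2; // one f16 per lane per load

  std::vector<int64_t> coalescedAddr(threadsPerWarp), swizzledAddr(threadsPerWarp);
  std::vector<int32_t> globalOffset(threadsPerWarp);
  std::vector<bool> pred(threadsPerWarp);
  for (int lane = 0; lane < threadsPerWarp; ++lane) {
    coalescedAddr[lane] = lane * vecBytes;         // where this lane must write
    swizzledAddr[lane] = (lane ^ 3) * vecBytes;    // where its element belongs
    globalOffset[lane] = 1000 + lane;              // per-lane global offset
    pred[lane] = (lane % 2 == 0);                  // per-lane mask bit
  }

  // "ballot": one bit per lane carrying that lane's predicate.
  uint64_t warpMask = 0;
  for (int lane = 0; lane < threadsPerWarp; ++lane)
    if (pred[lane]) warpMask |= (1ull << lane);

  for (int lane = 0; lane < threadsPerWarp; ++lane) {
    // laneOffset may be negative; selectLane stays in [0, threadsPerWarp)
    // because the swizzle never crosses a warp boundary.
    int laneOffset = int(swizzledAddr[lane] - coalescedAddr[lane]) / vecBytes;
    int selectLane = lane + laneOffset;

    // "shuffleIdx"/ds_bpermute: take the global offset from selectLane.
    int32_t offsetIn = globalOffset[selectLane];
    // Predicate of selectLane, extracted as one bit of the ballot mask.
    bool swizzledPred = (warpMask >> selectLane) & 1;

    if (lane < 8)
      printf("lane %2d loads offset %d (pred=%d) into coalesced addr %lld\n",
             lane, (int)offsetIn, (int)swizzledPred,
             (long long)coalescedAddr[lane]);
  }
}
```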

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 40 additions & 10 deletions
@@ -605,20 +605,14 @@ Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t) {
 }
 
 bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter,
-                                      RankedTensorType srcTy,
-                                      triton::gpu::MemDescType dstTy,
-                                      unsigned vectorSize) {
-  auto shape = srcTy.getShape();
-  LinearLayout srcLayout =
-      triton::gpu::toLinearLayout(shape, srcTy.getEncoding());
-  LinearLayout sharedLayout =
-      triton::gpu::toLinearLayout(shape, dstTy.getEncoding());
-  LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout);
+                                      const LinearLayout &srcToSharedLayout,
+                                      unsigned threadsPerWarp) {
+  auto contig = srcToSharedLayout.getNumConsecutiveInOut();
 
   StringAttr kLane = rewriter.getStringAttr("lane");
   for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) {
     auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0];
-    unsigned expected = vectorSize * (1 << inLane);
+    unsigned expected = contig * (1 << inLane);
     if (basis != expected) {
       LDBG("detected uncoalesced layout from blocked to shared in async copy "
            "for lane "
@@ -627,6 +621,42 @@ bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter,
       return false;
     }
   }
+  // Additionally we could swizzle based on the warp dimension so we need to
+  // check that when all bases are divided by contig, none of the first
+  // (log2(warpSize) + 1) bits are set to 1
+  assert(llvm::isPowerOf2_32(threadsPerWarp));
+  assert(llvm::isPowerOf2_32(contig));
+  unsigned mask = (threadsPerWarp * contig) - 1;
+  StringAttr kWarp = rewriter.getStringAttr("warp");
+  for (int inWarp : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kWarp))) {
+    auto basis = srcToSharedLayout.getBasis(kWarp, inWarp)[0];
+    if ((basis & mask) != 0) {
+      LDBG("detected uncoalesced layout from blocked to shared in async copy "
+           "for warp "
+           << inWarp);
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool doesSwizzleInsideWarp(RewriterBase &rewriter,
+                           const LinearLayout &srcToSharedLayout,
+                           unsigned threadsPerWarp) {
+  auto contig = srcToSharedLayout.getNumConsecutiveInOut();
+  // If all bases in lane dimension are below threadsPerWarp multiplied with the
+  // contiguity we do not swizzle across warp boundaries.
+  assert(llvm::isPowerOf2_32(threadsPerWarp));
+  unsigned upperLimit = threadsPerWarp * contig;
+
+  StringAttr kLane = rewriter.getStringAttr("lane");
+  for (int inLane : llvm::seq(srcToSharedLayout.getInDimSizeLog2(kLane))) {
+    auto basis = srcToSharedLayout.getBasis(kLane, inLane)[0];
+    if (basis >= upperLimit) {
+      return false;
+    }
+  }
   return true;
 }
 
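The two utility checks above reduce to simple conditions on the linear-layout bases. The following standalone mirror is an assumption-laden sketch, not the committed helpers: bases are plain integers, with lane basis i giving the shared-memory offset contributed by lane bit i, and `contig` playing the role of `srcToSharedLayout.getNumConsecutiveInOut()`; the function names `canCoalesceWrite` and `swizzlesInsideWarp` are my own.

```cpp
// Standalone mirror of the basis checks on plain integer bases.
#include <cassert>
#include <cstdio>
#include <vector>

bool canCoalesceWrite(const std::vector<unsigned> &laneBases,
                      const std::vector<unsigned> &warpBases,
                      unsigned contig, unsigned threadsPerWarp) {
  // Lane bit i must map to contig * 2^i so that consecutive lanes write
  // consecutive vec-sized chunks of shared memory.
  for (unsigned i = 0; i < laneBases.size(); ++i)
    if (laneBases[i] != contig * (1u << i))
      return false;
  // Warp bases must not land inside the contiguous per-warp region.
  assert((threadsPerWarp & (threadsPerWarp - 1)) == 0);
  unsigned mask = threadsPerWarp * contig - 1;
  for (unsigned basis : warpBases)
    if ((basis & mask) != 0)
      return false;
  return true;
}

bool swizzlesInsideWarp(const std::vector<unsigned> &laneBases, unsigned contig,
                        unsigned threadsPerWarp) {
  // If every lane basis stays below threadsPerWarp * contig, the permutation
  // induced by the swizzle never leaves the warp.
  unsigned upperLimit = threadsPerWarp * contig;
  for (unsigned basis : laneBases)
    if (basis >= upperLimit)
      return false;
  return true;
}

int main() {
  // 64 lanes, 2 contiguous f16 elements per lane, identity lane mapping.
  std::vector<unsigned> lane = {2, 4, 8, 16, 32, 64};
  std::vector<unsigned> warp = {128, 256, 512};
  printf("coalesced write: %d\n", (int)canCoalesceWrite(lane, warp, 2, 64));
  printf("in-warp swizzle: %d\n", (int)swizzlesInsideWarp(lane, 2, 64));
}
```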

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 8 additions & 3 deletions
@@ -88,9 +88,14 @@ Type scaleDotElemTypeToMLIRType(MLIRContext *ctx, triton::ScaleDotElemType t);
 // Returns true if we can perform coalesced write from the source encoding to
 // the destination encoding.
 bool canCoalesceWriteIntoSharedMemory(RewriterBase &rewriter,
-                                      RankedTensorType srcTy,
-                                      triton::gpu::MemDescType dstTy,
-                                      unsigned vectorSize);
+                                      const LinearLayout &srcToSharedLayout,
+                                      unsigned threadsPerWarp);
+
+// Returns true if the swizzling pattern does only swizzle the shared memory
+// offsets of a warp and does not exchange destination elements across warps
+bool doesSwizzleInsideWarp(RewriterBase &rewriter,
+                           const LinearLayout &srcToSharedLayout,
+                           unsigned threadsPerWarp);
 
 // Return true if op is used by DotScaledOp or UpcastMXFPOp ops.
 bool isUsedByDotScaledOp(Operation *op);
