Skip to content

Commit 77c9289

Browse files
Revert "[AMD] Reuse explicit swizzling pattern for ConvertLayoutOp (#7636)"
This reverts commit 8cb3a83.
1 parent 629e057 commit 77c9289

File tree

5 files changed

+183
-259
lines changed

5 files changed

+183
-259
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -557,14 +557,13 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
557557
// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
558558
// and computes a new offset (mlir::Value) by applying padding based on
559559
// shared memory layout.
560-
SmallVector<Value>
561-
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
562-
ArrayRef<Value> valsArray, // Input for store, output for load
563-
Type llvmElemTy, Value smemBase,
564-
std::function<Value(Value)> calcPaddedOffset,
565-
Value affineOffset, uint64_t maskSpanAffineOffset,
566-
RewriterBase &rewriter, const TargetInfoBase &targetInfo,
567-
Operation *localLoadOp = nullptr);
560+
SmallVector<Value> lowerLdStShared(
561+
Location loc, MLIRContext *ctx, LinearLayout cvt,
562+
ArrayRef<Value> valsArray, // Input for store, output for load
563+
Type llvmElemTy, Value smemBase,
564+
std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
565+
uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
566+
const TargetInfoBase &targetInfo, Operation *localLoadOp = nullptr);
568567

569568
// Lower an ld/st-like operation given a layout and a callback that creates the
570569
// PTX instruction Lowers to st when valArrays is empty, and to ld when it is
@@ -577,10 +576,10 @@ SmallVector<Value> lowerLdSt(
577576
ArrayRef<Value> valsArray, // Input for store, output for load
578577
Type llvmElemTy, Value smemBase,
579578
std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
580-
uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
579+
uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
581580
const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
582-
std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
583-
Value, int, VectorType)>
581+
std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
582+
ArrayRef<Value>, Value, int, VectorType)>
584583
lowerInst);
585584

586585
// Lower local_load/local_store via ld.shared/st.shared
@@ -589,7 +588,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
589588
LinearLayout cvt, // Map from registers to offset
590589
ArrayRef<Value> valsArray, // Input for store, empty for load
591590
Type llvmElemTy, triton::gpu::MemDescType srcTy,
592-
SharedMemoryObject smemObj, RewriterBase &rewriter,
591+
SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
593592
const TargetInfoBase &targetInfo,
594593
Operation *localLoadOp = nullptr);
595594

@@ -644,12 +643,6 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
644643
const LLVMTypeConverter *typeConverter,
645644
RewriterBase &rewriter);
646645

647-
LogicalResult
648-
transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
649-
const TargetInfoBase &targetInfo,
650-
const LLVMTypeConverter *typeConverter,
651-
RewriterBase &rewriter);
652-
653646
SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
654647
ArrayRef<Value> args,
655648
mlir::TypeID terminatorTypeId,

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 164 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,167 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110110
return success();
111111
}
112112

113+
SmallVector<Value> transferWithinBlockSwizzlingImpl(
114+
Location loc, ConversionPatternRewriter &rewriter,
115+
const LinearLayout &srcLayout, const LinearLayout &dstLayout,
116+
ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
117+
auto *ctx = rewriter.getContext();
118+
auto b = TritonLLVMOpBuilder(loc, rewriter);
119+
// We handle transformations recursively as they all need a preprocessing
120+
// and a postprocessing step.
121+
122+
// Handle pointer types as 64-bit integers
123+
if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
124+
auto llvmElemTyPtr = i64_ty;
125+
auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) {
126+
return b.ptrtoint(llvmElemTyPtr, v).getResult();
127+
}));
128+
auto outVals =
129+
transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
130+
newInVals, llvmElemTyPtr, smemBase);
131+
for (auto &v : outVals) {
132+
v = b.inttoptr(llvmElemTy, v);
133+
}
134+
return outVals;
135+
}
136+
137+
// Handle sub-byte elements like i1
138+
if (llvmElemTy.getIntOrFloatBitWidth() < 8) {
139+
// Upcast to i8
140+
auto i8ElemTy = i8_ty;
141+
auto newInVals = llvm::to_vector(llvm::map_range(
142+
inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
143+
auto outVals = transferWithinBlockSwizzlingImpl(
144+
loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
145+
for (auto &v : outVals) {
146+
v = b.trunc(llvmElemTy, v);
147+
}
148+
return outVals;
149+
}
150+
151+
// Remove broadcasting in src
152+
auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
153+
if (!removeBroadcastSrc.isIdentity()) {
154+
auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
155+
auto newInVals = removeBroadcastSrc.apply(inVals);
156+
return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
157+
newInVals, llvmElemTy, smemBase);
158+
}
159+
160+
// Remove broadcasting in dst
161+
auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
162+
if (!removeBroadcastDst.isIdentity()) {
163+
auto prmtDst = removeBroadcastDst.apply(dstLayout);
164+
auto outVals = transferWithinBlockSwizzlingImpl(
165+
loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
166+
return broadcastAs(outVals, dstLayout);
167+
}
168+
169+
// At this point we have a type that's at least 8-bit
170+
// and we don't have broadcasting in the registers
171+
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
172+
auto smem = optimalSwizzling(srcLayout, dstLayout, bitwidth);
173+
174+
// Extract reps from smem
175+
auto kReg = str_attr("register");
176+
auto kReps = str_attr("reps");
177+
auto nReps = smem.getInDimSize(kReps);
178+
auto reps = LinearLayout::identity1D(nReps, kReg, kReps);
179+
180+
auto totalStoreCvt = srcLayout.invertAndCompose(smem);
181+
auto totalLoadCvt = dstLayout.invertAndCompose(smem);
182+
183+
// The permutation exists by construction of the reps dimension in
184+
// optimalSwizzling
185+
auto permStore =
186+
regPermForDivide(totalStoreCvt, reps, /*left=*/false).value();
187+
totalStoreCvt = permStore.apply(totalStoreCvt);
188+
auto permutedInVals = permStore.apply(inVals);
189+
auto permLoad =
190+
regPermForDivide(totalLoadCvt, reps, /*left=*/false).value();
191+
totalLoadCvt = permLoad.apply(totalLoadCvt);
192+
193+
// Remove the reps and flatten into offset
194+
auto storeCvt = *divideRight(totalStoreCvt, reps);
195+
auto loadCvt = *divideRight(totalLoadCvt, reps);
196+
auto kOffset = str_attr("offset");
197+
storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
198+
loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
199+
200+
auto tileSize = storeCvt.getInDimSize(kReg);
201+
202+
assert(permutedInVals.size() == tileSize * nReps);
203+
SmallVector<Value> outVals;
204+
auto noPaddingOffset = [](Value v) { return v; };
205+
auto affineOffset = b.i32_val(0);
206+
auto maskSpanAffineOffset = 0;
207+
for (int i = 0; i < nReps; ++i) {
208+
if (i > 0)
209+
b.barrier();
210+
211+
auto tileInVals =
212+
ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
213+
// Store
214+
lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
215+
noPaddingOffset, affineOffset, maskSpanAffineOffset,
216+
rewriter, targetInfo);
217+
b.barrier();
218+
// Load
219+
SmallVector<Value> tileOutVals = lowerLdStShared(
220+
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
221+
affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
222+
llvm::append_range(outVals, tileOutVals);
223+
}
224+
225+
// Undo the permLoad used to divideRight
226+
outVals = permLoad.inverse().apply(outVals);
227+
return outVals;
228+
}
229+
230+
LogicalResult
231+
transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
232+
ConversionPatternRewriter &rewriter) const {
233+
// Fallback for now to standard lowering if it can use stmatrix
234+
auto scratchConfig =
235+
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
236+
bool isStMatrix = targetInfo.canUseStMatrix(
237+
op.getSrc().getType(), scratchConfig.repShape,
238+
scratchConfig.paddedRepShape, scratchConfig.order,
239+
/*swizzleByteSize=*/0);
240+
if (isStMatrix) {
241+
return failure();
242+
}
243+
244+
auto loc = op.getLoc();
245+
auto *ctx = op.getContext();
246+
auto srcTy = op.getSrc().getType();
247+
auto dstTy = op.getType();
248+
249+
// Remove the kBlock dimension from the layout as it's the identity in the
250+
// cvt
251+
auto srcLayout = toLinearLayout(srcTy);
252+
auto dstLayout = toLinearLayout(dstTy);
253+
auto kReg = str_attr("register");
254+
auto kLane = str_attr("lane");
255+
auto kWarp = str_attr("warp");
256+
srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
257+
to_vector(srcLayout.getOutDimNames()));
258+
dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
259+
to_vector(dstLayout.getOutDimNames()));
260+
261+
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
262+
auto smemBase =
263+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
264+
auto inVals = unpackLLElements(loc, src, rewriter);
265+
auto outVals = transferWithinBlockSwizzlingImpl(
266+
loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
267+
268+
Value result =
269+
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
270+
rewriter.replaceOp(op, result);
271+
return success();
272+
}
273+
113274
LogicalResult transferWithinBlock(ConvertLayoutOp op,
114275
const LinearLayout &srcLayout,
115276
const LinearLayout &dstLayout,
@@ -118,8 +279,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
118279
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
119280

120281
// Try to use swizzling to implement the conversion
121-
if (succeeded(transferWithinBlockSwizzling(op, adaptor.getSrc(), targetInfo,
122-
getTypeConverter(), rewriter))) {
282+
// HACK Remove once AMD tests pass for the swizzling path
283+
if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
284+
op, adaptor.getSrc(), rewriter))) {
123285
return success();
124286
}
125287

0 commit comments

Comments
 (0)