Skip to content

Commit 629e057

Browse files
Merge commit '8cb3a831eefcb5d34a8e4b9a8e6129e2c4feec43'
2 parents 7e20e48 + 8cb3a83 commit 629e057

File tree

5 files changed

+259
-183
lines changed

5 files changed

+259
-183
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -557,13 +557,14 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
557557
// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
558558
// and computes a new offset (mlir::Value) by applying padding based on
559559
// shared memory layout.
560-
SmallVector<Value> lowerLdStShared(
561-
Location loc, MLIRContext *ctx, LinearLayout cvt,
562-
ArrayRef<Value> valsArray, // Input for store, output for load
563-
Type llvmElemTy, Value smemBase,
564-
std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
565-
uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
566-
const TargetInfoBase &targetInfo, Operation *localLoadOp = nullptr);
560+
SmallVector<Value>
561+
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
562+
ArrayRef<Value> valsArray, // Input for store, output for load
563+
Type llvmElemTy, Value smemBase,
564+
std::function<Value(Value)> calcPaddedOffset,
565+
Value affineOffset, uint64_t maskSpanAffineOffset,
566+
RewriterBase &rewriter, const TargetInfoBase &targetInfo,
567+
Operation *localLoadOp = nullptr);
567568

568569
// Lower an ld/st-like operation given a layout and a callback that creates the
569570
// PTX instruction Lowers to st when valArrays is empty, and to ld when it is
@@ -576,10 +577,10 @@ SmallVector<Value> lowerLdSt(
576577
ArrayRef<Value> valsArray, // Input for store, output for load
577578
Type llvmElemTy, Value smemBase,
578579
std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
579-
uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
580+
uint64_t maskSpanAffineOffset, RewriterBase &rewriter,
580581
const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
581-
std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
582-
ArrayRef<Value>, Value, int, VectorType)>
582+
std::function<SmallVector<Value>(RewriterBase &, Location, ArrayRef<Value>,
583+
Value, int, VectorType)>
583584
lowerInst);
584585

585586
// Lower local_load/local_store via ld.shared/st.shared
@@ -588,7 +589,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
588589
LinearLayout cvt, // Map from registers to offset
589590
ArrayRef<Value> valsArray, // Input for store, empty for load
590591
Type llvmElemTy, triton::gpu::MemDescType srcTy,
591-
SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
592+
SharedMemoryObject smemObj, RewriterBase &rewriter,
592593
const TargetInfoBase &targetInfo,
593594
Operation *localLoadOp = nullptr);
594595

@@ -643,6 +644,12 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
643644
const LLVMTypeConverter *typeConverter,
644645
RewriterBase &rewriter);
645646

647+
LogicalResult
648+
transferWithinBlockSwizzling(triton::gpu::ConvertLayoutOp op, Value src,
649+
const TargetInfoBase &targetInfo,
650+
const LLVMTypeConverter *typeConverter,
651+
RewriterBase &rewriter);
652+
646653
SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
647654
ArrayRef<Value> args,
648655
mlir::TypeID terminatorTypeId,

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -110,167 +110,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110110
return success();
111111
}
112112

113-
SmallVector<Value> transferWithinBlockSwizzlingImpl(
114-
Location loc, ConversionPatternRewriter &rewriter,
115-
const LinearLayout &srcLayout, const LinearLayout &dstLayout,
116-
ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
117-
auto *ctx = rewriter.getContext();
118-
auto b = TritonLLVMOpBuilder(loc, rewriter);
119-
// We handle transformations recursively as they all need a preprocessing
120-
// and a postprocessing step.
121-
122-
// Handle pointer types as 64-bit integers
123-
if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
124-
auto llvmElemTyPtr = i64_ty;
125-
auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) {
126-
return b.ptrtoint(llvmElemTyPtr, v).getResult();
127-
}));
128-
auto outVals =
129-
transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
130-
newInVals, llvmElemTyPtr, smemBase);
131-
for (auto &v : outVals) {
132-
v = b.inttoptr(llvmElemTy, v);
133-
}
134-
return outVals;
135-
}
136-
137-
// Handle sub-byte elements like i1
138-
if (llvmElemTy.getIntOrFloatBitWidth() < 8) {
139-
// Upcast to i8
140-
auto i8ElemTy = i8_ty;
141-
auto newInVals = llvm::to_vector(llvm::map_range(
142-
inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
143-
auto outVals = transferWithinBlockSwizzlingImpl(
144-
loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
145-
for (auto &v : outVals) {
146-
v = b.trunc(llvmElemTy, v);
147-
}
148-
return outVals;
149-
}
150-
151-
// Remove broadcasting in src
152-
auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
153-
if (!removeBroadcastSrc.isIdentity()) {
154-
auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
155-
auto newInVals = removeBroadcastSrc.apply(inVals);
156-
return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
157-
newInVals, llvmElemTy, smemBase);
158-
}
159-
160-
// Remove broadcasting in dst
161-
auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
162-
if (!removeBroadcastDst.isIdentity()) {
163-
auto prmtDst = removeBroadcastDst.apply(dstLayout);
164-
auto outVals = transferWithinBlockSwizzlingImpl(
165-
loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
166-
return broadcastAs(outVals, dstLayout);
167-
}
168-
169-
// At this point we have a type that's at least 8-bit
170-
// and we don't have broadcasting in the registers
171-
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
172-
auto smem = optimalSwizzling(srcLayout, dstLayout, bitwidth);
173-
174-
// Extract reps from smem
175-
auto kReg = str_attr("register");
176-
auto kReps = str_attr("reps");
177-
auto nReps = smem.getInDimSize(kReps);
178-
auto reps = LinearLayout::identity1D(nReps, kReg, kReps);
179-
180-
auto totalStoreCvt = srcLayout.invertAndCompose(smem);
181-
auto totalLoadCvt = dstLayout.invertAndCompose(smem);
182-
183-
// The permutation exists by construction of the reps dimension in
184-
// optimalSwizzling
185-
auto permStore =
186-
regPermForDivide(totalStoreCvt, reps, /*left=*/false).value();
187-
totalStoreCvt = permStore.apply(totalStoreCvt);
188-
auto permutedInVals = permStore.apply(inVals);
189-
auto permLoad =
190-
regPermForDivide(totalLoadCvt, reps, /*left=*/false).value();
191-
totalLoadCvt = permLoad.apply(totalLoadCvt);
192-
193-
// Remove the reps and flatten into offset
194-
auto storeCvt = *divideRight(totalStoreCvt, reps);
195-
auto loadCvt = *divideRight(totalLoadCvt, reps);
196-
auto kOffset = str_attr("offset");
197-
storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
198-
loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
199-
200-
auto tileSize = storeCvt.getInDimSize(kReg);
201-
202-
assert(permutedInVals.size() == tileSize * nReps);
203-
SmallVector<Value> outVals;
204-
auto noPaddingOffset = [](Value v) { return v; };
205-
auto affineOffset = b.i32_val(0);
206-
auto maskSpanAffineOffset = 0;
207-
for (int i = 0; i < nReps; ++i) {
208-
if (i > 0)
209-
b.barrier();
210-
211-
auto tileInVals =
212-
ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
213-
// Store
214-
lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
215-
noPaddingOffset, affineOffset, maskSpanAffineOffset,
216-
rewriter, targetInfo);
217-
b.barrier();
218-
// Load
219-
SmallVector<Value> tileOutVals = lowerLdStShared(
220-
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
221-
affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
222-
llvm::append_range(outVals, tileOutVals);
223-
}
224-
225-
// Undo the permLoad used to divideRight
226-
outVals = permLoad.inverse().apply(outVals);
227-
return outVals;
228-
}
229-
230-
LogicalResult
231-
transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
232-
ConversionPatternRewriter &rewriter) const {
233-
// Fallback for now to standard lowering if it can use stmatrix
234-
auto scratchConfig =
235-
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
236-
bool isStMatrix = targetInfo.canUseStMatrix(
237-
op.getSrc().getType(), scratchConfig.repShape,
238-
scratchConfig.paddedRepShape, scratchConfig.order,
239-
/*swizzleByteSize=*/0);
240-
if (isStMatrix) {
241-
return failure();
242-
}
243-
244-
auto loc = op.getLoc();
245-
auto *ctx = op.getContext();
246-
auto srcTy = op.getSrc().getType();
247-
auto dstTy = op.getType();
248-
249-
// Remove the kBlock dimension from the layout as it's the identity in the
250-
// cvt
251-
auto srcLayout = toLinearLayout(srcTy);
252-
auto dstLayout = toLinearLayout(dstTy);
253-
auto kReg = str_attr("register");
254-
auto kLane = str_attr("lane");
255-
auto kWarp = str_attr("warp");
256-
srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
257-
to_vector(srcLayout.getOutDimNames()));
258-
dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
259-
to_vector(dstLayout.getOutDimNames()));
260-
261-
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
262-
auto smemBase =
263-
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
264-
auto inVals = unpackLLElements(loc, src, rewriter);
265-
auto outVals = transferWithinBlockSwizzlingImpl(
266-
loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
267-
268-
Value result =
269-
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);
270-
rewriter.replaceOp(op, result);
271-
return success();
272-
}
273-
274113
LogicalResult transferWithinBlock(ConvertLayoutOp op,
275114
const LinearLayout &srcLayout,
276115
const LinearLayout &dstLayout,
@@ -279,9 +118,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
279118
assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
280119

281120
// Try to use swizzling to implement the conversion
282-
// HACK Remove once AMD tests pass for the swizzling path
283-
if (targetInfo.isCuda() && succeeded(transferWithinBlockSwizzling(
284-
op, adaptor.getSrc(), rewriter))) {
121+
if (succeeded(transferWithinBlockSwizzling(op, adaptor.getSrc(), targetInfo,
122+
getTypeConverter(), rewriter))) {
285123
return success();
286124
}
287125

0 commit comments

Comments
 (0)