Skip to content

Commit a0b4c91

Browse files
Merge commit '40f71635fecd20059a36b78be303b53058215310'
2 parents 81946f9 + 40f7163 commit a0b4c91

File tree

16 files changed

+295
-255
lines changed

16 files changed

+295
-255
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,18 @@ void storeDistributedToShared(
562562
RewriterBase &rewriter, const TargetInfoBase &target,
563563
std::pair<size_t, Type> *const llvmOpCount = nullptr);
564564

565+
// Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
566+
// We might want to merge them at some point, but having to support
567+
// ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
568+
// Lowers to st when valArrays is empty, and to ld when it is not,
569+
// and returns the output values.
570+
SmallVector<Value>
571+
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
572+
ArrayRef<Value> valsArray, // Input for store, output for load
573+
Type llvmElemTy, Value smemBase,
574+
ConversionPatternRewriter &rewriter,
575+
const TargetInfoBase &targetInfo);
576+
565577
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
566578
RewriterBase &rewriter);
567579

include/triton/Tools/LinearLayout.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,7 @@ class ColumnAction {
806806
SmallVector<size_t> action;
807807
StringAttr inDim;
808808
size_t inSizeLog2;
809-
bool isIdentity = true;
809+
bool m_isIdentity = true;
810810

811811
public:
812812
ColumnAction() = default;
@@ -817,7 +817,8 @@ class ColumnAction {
817817
assert(it == action.end() || *it < inSizeLog2);
818818
// In many cases the action will be the identity, so we save that as an
819819
// early return
820-
isIdentity = action.size() == inSizeLog2 && llvm::is_sorted(action);
820+
m_isIdentity = action.size() == inSizeLog2 &&
821+
llvm::equal(action, llvm::seq<size_t>(action.size()));
821822
}
822823

823824
// Act on the columns of a layout
@@ -837,6 +838,9 @@ class ColumnAction {
837838
// Inverse of the action
838839
ColumnAction inverse() const;
839840

841+
// Returns true if the action is the identity
842+
bool isIdentity() const { return m_isIdentity; }
843+
840844
std::string toString() const;
841845
};
842846

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 92 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -114,190 +114,65 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
114114
return success();
115115
}
116116

117-
std::pair<int, ColumnAction> largestVectorisation(MLIRContext *ctx,
118-
const LinearLayout &cvt,
119-
int bitwidth) const {
120-
// Find the largest vectorisation we can use:
121-
StringAttr kReg = str_attr("register");
122-
StringAttr kOffset = str_attr("offset");
123-
LinearLayout quot;
124-
LinearLayout tile;
125-
ColumnAction permutation;
126-
for (int v = 128 / bitwidth; v >= 1; v /= 2) {
127-
tile = LinearLayout::identity1D(v, kReg, kOffset);
128-
auto maybePerm = regPermForDivide(cvt, tile, /*left=*/true);
129-
if (!maybePerm) {
130-
continue;
131-
}
132-
permutation = *maybePerm;
133-
auto newCvt = permutation.apply(cvt);
134-
auto maybeQuot = divideLeft(newCvt, tile);
135-
if (!maybeQuot) {
136-
continue;
137-
}
138-
return {v, permutation};
139-
}
140-
llvm_unreachable("No vectorisation found");
141-
}
142-
143-
// Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
144-
// We might want to merge them at some point, but having to support
145-
// ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
146-
// Lowers to st when valArrays is empty, and to ld when it is not,
147-
// and returns the output values.
148-
SmallVector<Value>
149-
lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
150-
int elemsPerVec,
151-
ArrayRef<Value> valsArray, // Input for store, output for load
152-
Type llvmElemTy, Value smemBase,
153-
ConversionPatternRewriter &rewriter) const {
154-
auto vals = to_vector(valsArray);
155-
bool isStore = !vals.empty();
117+
SmallVector<Value> transferWithinBlockSwizzlingImpl(
118+
Location loc, ConversionPatternRewriter &rewriter,
119+
const LinearLayout &srcLayout, const LinearLayout &dstLayout,
120+
ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
121+
auto *ctx = rewriter.getContext();
156122
auto b = TritonLLVMOpBuilder(loc, rewriter);
157-
auto smemPtrTy = ptr_ty(ctx, 3);
158-
auto kReg = str_attr("register");
159-
auto kLane = str_attr("lane");
160-
auto kWarp = str_attr("warp");
161-
auto kOffset = str_attr("offset");
162-
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
163-
164-
auto [vec, permutation] = largestVectorisation(ctx, cvt, bitwidth);
165-
assert(vec >= elemsPerVec);
166-
elemsPerVec = vec;
167-
168-
cvt = permutation.apply(cvt);
169-
if (isStore) {
170-
vals = permutation.apply(vals);
171-
}
172-
173-
auto tile = LinearLayout::identity1D(vec, kReg, kOffset);
174-
auto quot = *divideLeft(cvt, tile);
175-
LinearLayout reps = zerosLike(tile) * quot;
176-
177-
auto [nAdditive, permStrides] = actionAdditiveStrides(reps);
178-
reps = permStrides.apply(reps);
179-
if (isStore) {
180-
vals = permStrides.apply(vals);
181-
}
182-
183-
// PTX expects the address increments to be done in bytes
184-
// If we don't perform the computations in i8, the compiler would
185-
// have to divide the computation by bitwdith / 8 and then lift this
186-
// shl, which often it's not able to do.
187-
auto i8Tile =
188-
zerosLike(LinearLayout::identity1D(bitwidth / 8, kReg, kOffset));
189-
auto i8Reps = i8Tile * reps;
190-
191-
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
192-
auto regBaseI8 =
193-
applyLinearLayout(
194-
loc, rewriter, i8Reps,
195-
{{kReg, b.i32_val(0)}, {kLane, laneId}, {kWarp, warpId}})[0]
196-
.second;
197-
SmallVector<Value> outVals;
198-
for (int i = 0; i < cvt.getInDimSize(kReg); i += nAdditive) {
199-
auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second;
200-
auto regIdxI8 = regIdx * (bitwidth / 8);
201-
Value offset = b.xor_(regBaseI8, b.i32_val(regIdxI8));
202-
for (int j = 0; j < nAdditive; j += elemsPerVec) {
203-
// all these constants will go as immediate values to LDS/STS
204-
auto regIdxAdd =
205-
reps.apply({{kReg, j}, {kLane, 0}, {kWarp, 0}})[0].second;
206-
auto regIdxAddI8 = regIdxAdd * (bitwidth / 8);
207-
Value innerOffset = b.add(offset, b.i32_val(regIdxAddI8));
208-
auto vecAddr = b.gep(smemPtrTy, i8_ty, smemBase, innerOffset,
209-
LLVM::GEPNoWrapFlags::inbounds);
210-
// Lezcano: Do we want to use getFreeVariableMasks for pred or nah?
211-
if (isStore) {
212-
Value valsVec = packLLVector(
213-
loc, ArrayRef<Value>(vals).slice(i + j, elemsPerVec), rewriter);
214-
targetInfo.storeDShared(rewriter, loc, vecAddr, std::nullopt, valsVec,
215-
/*pred=*/b.true_val());
216-
} else {
217-
Value valsVec =
218-
targetInfo.loadDShared(rewriter, loc, vecAddr, std::nullopt,
219-
vec_ty(llvmElemTy, elemsPerVec),
220-
/*pred=*/b.true_val());
221-
llvm::append_range(outVals, unpackLLVector(loc, valsVec, rewriter));
222-
}
123+
// We handle transformations recursively as they all need a preprocessing
124+
// and a postprocessing step.
125+
126+
// Handle pointer types as 64-bit integers
127+
if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
128+
auto llvmElemTyPtr = i64_ty;
129+
auto newInVals = llvm::to_vector(llvm::map_range(inVals, [&](Value v) {
130+
return b.ptrtoint(llvmElemTyPtr, v).getResult();
131+
}));
132+
auto outVals =
133+
transferWithinBlockSwizzlingImpl(loc, rewriter, srcLayout, dstLayout,
134+
newInVals, llvmElemTyPtr, smemBase);
135+
for (auto &v : outVals) {
136+
v = b.inttoptr(llvmElemTy, v);
223137
}
138+
return outVals;
224139
}
225140

226-
// Permute the values back if we are loading
227-
if (!isStore) {
228-
auto invPermStrides = permStrides.inverse();
229-
outVals = invPermStrides.apply(outVals);
230-
auto invPerm = permutation.inverse();
231-
outVals = invPerm.apply(outVals);
232-
}
233-
return outVals;
234-
}
235-
236-
LogicalResult
237-
transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
238-
ConversionPatternRewriter &rewriter) const {
239-
// Fallback for now to standard lowering if it can use stmatrix
240-
auto scratchConfig =
241-
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
242-
bool isStMatrix = targetInfo.canUseStMatrix(
243-
op.getSrc().getType(), scratchConfig.repShape,
244-
scratchConfig.paddedRepShape, scratchConfig.order,
245-
/*swizzleByteSize=*/0);
246-
if (isStMatrix) {
247-
return failure();
248-
}
249-
250-
auto loc = op.getLoc();
251-
auto *ctx = op.getContext();
252-
auto b = TritonLLVMOpBuilder(loc, rewriter);
253-
auto srcTy = op.getSrc().getType();
254-
auto dstTy = op.getType();
255-
auto bitwidth = isa<PointerType>(srcTy.getElementType())
256-
? kPtrBitWidth
257-
: srcTy.getElementTypeBitWidth();
258-
259-
auto srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
260-
auto dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
261-
auto origDstLayout = dstLayout;
262-
263-
// We remove the Block dimension from the layout as it's the identity in the
264-
// cvt
265-
auto kRegister = str_attr("register");
266-
auto kLane = str_attr("lane");
267-
auto kWarp = str_attr("warp");
268-
srcLayout = srcLayout.sublayout({kRegister, kLane, kWarp},
269-
to_vector(srcLayout.getOutDimNames()));
270-
dstLayout = dstLayout.sublayout({kRegister, kLane, kWarp},
271-
to_vector(dstLayout.getOutDimNames()));
272-
273141
// Handle sub-byte elements like i1
274-
auto inVals = unpackLLElements(loc, src, rewriter);
275-
276-
bool isSubByte = bitwidth < 8;
277-
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
278-
if (isSubByte) {
142+
if (llvmElemTy.getIntOrFloatBitWidth() < 8) {
279143
// Upcast to i8
280-
bitwidth = 8;
281-
llvmElemTy = i8_ty;
282-
for (auto &v : inVals) {
283-
v = b.zext(llvmElemTy, v);
284-
}
285-
}
286-
bool isPtr = isa<PointerType>(srcTy.getElementType());
287-
if (isPtr) {
288-
llvmElemTy =
289-
getTypeConverter()->convertType(IntegerType::get(ctx, kPtrBitWidth));
290-
for (auto &v : inVals) {
291-
v = b.ptrtoint(llvmElemTy, v);
144+
auto i8ElemTy = i8_ty;
145+
auto newInVals = llvm::to_vector(llvm::map_range(
146+
inVals, [&](Value v) { return b.zext(i8ElemTy, v).getResult(); }));
147+
auto outVals = transferWithinBlockSwizzlingImpl(
148+
loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
149+
for (auto &v : outVals) {
150+
v = b.trunc(llvmElemTy, v);
292151
}
152+
return outVals;
293153
}
294154

295-
// Remove register broadcast from src and dst and input values
155+
// Remove broadcasting in src
296156
auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
297-
srcLayout = removeBroadcastSrc.apply(srcLayout);
298-
inVals = removeBroadcastSrc.apply(inVals);
299-
dstLayout = actionRemoveBroadcastedRegs(dstLayout).apply(dstLayout);
157+
if (!removeBroadcastSrc.isIdentity()) {
158+
auto prmtSrc = removeBroadcastSrc.apply(srcLayout);
159+
auto newInVals = removeBroadcastSrc.apply(inVals);
160+
return transferWithinBlockSwizzlingImpl(loc, rewriter, prmtSrc, dstLayout,
161+
newInVals, llvmElemTy, smemBase);
162+
}
300163

164+
// Remove broadcasting in dst
165+
auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
166+
if (!removeBroadcastDst.isIdentity()) {
167+
auto prmtDst = removeBroadcastDst.apply(dstLayout);
168+
auto outVals = transferWithinBlockSwizzlingImpl(
169+
loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
170+
return broadcastAs(outVals, dstLayout);
171+
}
172+
173+
// At this point we have a type that's at least 8-bit
174+
// and we don't have broadcasting in the registers
175+
auto bitwidth = llvmElemTy.getIntOrFloatBitWidth();
301176
auto smem = optimalSwizzling(srcLayout, dstLayout, bitwidth);
302177

303178
// Extract reps from smem
@@ -314,7 +189,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
314189
auto permStore =
315190
regPermForDivide(totalStoreCvt, reps, /*left=*/false).value();
316191
totalStoreCvt = permStore.apply(totalStoreCvt);
317-
inVals = permStore.apply(inVals);
192+
auto permutedInVals = permStore.apply(inVals);
318193
auto permLoad =
319194
regPermForDivide(totalLoadCvt, reps, /*left=*/false).value();
320195
totalLoadCvt = permLoad.apply(totalLoadCvt);
@@ -326,49 +201,68 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
326201
storeCvt = storeCvt.reshapeOuts({{kOffset, storeCvt.getTotalOutDimSize()}});
327202
loadCvt = loadCvt.reshapeOuts({{kOffset, loadCvt.getTotalOutDimSize()}});
328203

329-
Value smemBase =
330-
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
331-
332204
auto tileSize = storeCvt.getInDimSize(kReg);
333205

334-
assert(inVals.size() == tileSize * nReps);
206+
assert(permutedInVals.size() == tileSize * nReps);
335207
SmallVector<Value> outVals;
336-
auto elemsPerVec = smem.getInDimSize(str_attr("vector"));
337208
for (int i = 0; i < nReps; ++i) {
338209
if (i > 0)
339210
b.barrier();
340211

341-
auto tileInVals = ArrayRef<Value>(inVals).slice(i * tileSize, tileSize);
212+
auto tileInVals =
213+
ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
342214
// Store
343-
lowerLdStShared(loc, ctx, storeCvt, elemsPerVec, tileInVals, llvmElemTy,
344-
smemBase, rewriter);
215+
lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
216+
rewriter, targetInfo);
345217
b.barrier();
346218
// Load
347219
SmallVector<Value> tileOutVals = lowerLdStShared(
348-
loc, ctx, loadCvt, elemsPerVec, {}, llvmElemTy, smemBase, rewriter);
220+
loc, ctx, loadCvt, {}, llvmElemTy, smemBase, rewriter, targetInfo);
349221
llvm::append_range(outVals, tileOutVals);
350222
}
351223

352224
// Undo the permLoad used to divideRight
353225
outVals = permLoad.inverse().apply(outVals);
226+
return outVals;
227+
}
354228

355-
// Unwrap sub-byte elements if necessary
356-
if (isSubByte) {
357-
auto llvmElemTyOrig =
358-
getTypeConverter()->convertType(srcTy.getElementType());
359-
for (auto &v : outVals) {
360-
v = b.trunc(llvmElemTyOrig, v);
361-
}
362-
} else if (isPtr) {
363-
auto llvmElemTyOrig =
364-
getTypeConverter()->convertType(srcTy.getElementType());
365-
for (auto &v : outVals) {
366-
v = b.inttoptr(llvmElemTyOrig, v);
367-
}
229+
LogicalResult
230+
transferWithinBlockSwizzling(ConvertLayoutOp op, Value src,
231+
ConversionPatternRewriter &rewriter) const {
232+
// Fallback for now to standard lowering if it can use stmatrix
233+
auto scratchConfig =
234+
getScratchConfigForCvt(op.getSrc().getType(), op.getType());
235+
bool isStMatrix = targetInfo.canUseStMatrix(
236+
op.getSrc().getType(), scratchConfig.repShape,
237+
scratchConfig.paddedRepShape, scratchConfig.order,
238+
/*swizzleByteSize=*/0);
239+
if (isStMatrix) {
240+
return failure();
368241
}
369242

370-
// Undo the removeBroadcastSrc
371-
outVals = broadcastAs(outVals, origDstLayout);
243+
auto loc = op.getLoc();
244+
auto *ctx = op.getContext();
245+
auto srcTy = op.getSrc().getType();
246+
auto dstTy = op.getType();
247+
248+
// Remove the kBlock dimension from the layout as it's the identity in the
249+
// cvt
250+
auto srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
251+
auto dstLayout = toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
252+
auto kReg = str_attr("register");
253+
auto kLane = str_attr("lane");
254+
auto kWarp = str_attr("warp");
255+
srcLayout = srcLayout.sublayout({kReg, kLane, kWarp},
256+
to_vector(srcLayout.getOutDimNames()));
257+
dstLayout = dstLayout.sublayout({kReg, kLane, kWarp},
258+
to_vector(dstLayout.getOutDimNames()));
259+
260+
auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
261+
auto smemBase =
262+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
263+
auto inVals = unpackLLElements(loc, src, rewriter);
264+
auto outVals = transferWithinBlockSwizzlingImpl(
265+
loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
372266

373267
Value result =
374268
packLLElements(loc, getTypeConverter(), outVals, rewriter, dstTy);

0 commit comments

Comments
 (0)