Skip to content

Commit 9cf9c63

Browse files
Sync ConvertLayoutOpToLLVM.cpp from upstream (#2615)
`lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp` is now identical to upstream.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent c637c07 commit 9cf9c63

File tree

1 file changed

+64
-21
lines changed

1 file changed

+64
-21
lines changed

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 64 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
288288
return rewriter.notifyMatchFailure(
289289
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
290290
}
291+
LinearLayout srcLayout =
292+
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
293+
LinearLayout dstLayout =
294+
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
295+
296+
StringAttr kBlock = str_attr("block");
297+
StringAttr kWarp = str_attr("warp");
298+
StringAttr kLane = str_attr("lane");
299+
StringAttr kRegister = str_attr("register");
291300

292301
assert(to_vector(conversion->getInDimNames()) ==
293302
to_vector(conversion->getOutDimNames()));
294303
auto dims = conversion->getInDimNames();
295-
if (llvm::is_contained(dims, str_attr("block"))) {
304+
if (llvm::is_contained(dims, kBlock)) {
296305
// Case 1: Transfer between values in different CTAs.
297306
// This requires moving values through distributed shared memory.
298307
return rewriter.notifyMatchFailure(
299308
op, "NYI: Transfer between different CTAs");
300-
} else if (llvm::is_contained(dims, str_attr("warp"))) {
309+
} else if (llvm::is_contained(dims, kWarp)) {
301310
// Case 2: Transfer between values in the same CTA, in which case we move
302311
// values through shared memory.
303-
LinearLayout srcLayout =
304-
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
305-
LinearLayout dstLayout =
306-
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
307312
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
308-
} else if (llvm::is_contained(dims, str_attr("lane"))) {
313+
} else if (llvm::is_contained(dims, kLane)) {
309314
// Case 3. Transfer between values in the same warp, in which case we try
310315
// to move values using warp shuffles, though if the pattern is
311316
// complicated enough we may fall back to using shared memory
312317
// TODO(Keren): implement warp shuffle instead of using the general
313318
// approach that uses shared memory
314-
LinearLayout srcLayout =
315-
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
316-
LinearLayout dstLayout =
317-
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
318319
return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
319-
} else if (llvm::is_contained(dims, str_attr("register"))) {
320+
} else if (llvm::is_contained(dims, kRegister) ||
321+
dstLayout.getInDimSize(kRegister) !=
322+
srcLayout.getInDimSize(kRegister)) {
320323
// Case 4. Transfer between values in the same thread, in which case we
321324
// simply reorder the elements of adaptor.getSrc().
322-
return transferWithinThread(op, *conversion, adaptor, rewriter);
325+
return transferWithinThread(
326+
op, dstLayout.getFreeVariableMasks()[kRegister],
327+
dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
323328
} else {
324-
// The two layouts are equivalent. We should probably remove these in
325-
// RemoveLayoutConversion.
329+
// Case 5. The two layouts are equivalent. We should probably remove
330+
// these in RemoveLayoutConversion.
326331
rewriter.replaceOp(op, adaptor.getSrc());
327332
return success();
328333
}
329334
}
330335

331336
LogicalResult
332-
transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
333-
OpAdaptor adaptor,
337+
transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
338+
const LinearLayout &conversion, OpAdaptor adaptor,
334339
ConversionPatternRewriter &rewriter) const {
335340
MLIRContext *ctx = op.getContext();
336341
auto loc = op.getLoc();
337342
StringAttr kRegister = str_attr("register");
338343
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
339344

340345
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
341-
SmallVector<Value> outVals;
342-
outVals.resize(conversion.getInDimSize(kRegister));
343-
for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
344-
auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
346+
SmallVector<Value> outVals(numRegs);
347+
for (int i = 0; i < outVals.size(); i++) {
348+
// Remove free masks from the register index
349+
// For example, if idx = 0b00111, and masks = 0b00100, then we get
350+
// 0b00011. It means that register 7 (0b111) has the same value as
351+
// register 3 (0b011).
352+
auto idx = i & (~regMasks);
353+
auto srcIdx = conversion.hasInDim(kRegister)
354+
? conversion.apply({{kRegister, idx}}).begin()->second
355+
: idx;
345356
outVals[i] = inVals[srcIdx];
346357
}
347358
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
@@ -372,6 +383,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
372383
}
373384
return true;
374385
}
386+
if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
387+
if (auto nvidiaMma =
388+
dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
389+
if (product(getCTAsPerCGA(nvidiaMma)) > 1) {
390+
return false;
391+
}
392+
if (useLegacyMMAConversion) {
393+
return false;
394+
}
395+
// FIXME [Dot LL]
396+
// Enabling LL path for buggy kWidth path
397+
bool largeKWidth =
398+
dotOperand.getKWidth() * dstTy.getElementTypeBitWidth() > 64;
399+
return largeKWidth && nvidiaMma.isAmpere();
400+
}
401+
}
375402
if (isa<BlockedEncodingAttr>(layout)) {
376403
return true;
377404
}
@@ -431,6 +458,22 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
431458
}
432459
}
433460

461+
// FIXME [Dot LL]
462+
// We know it's just for largeKWidth case in Ampere
463+
// In this case, we need to pack the outputs into i32
464+
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
465+
auto concat = [&](Value a, Value b) {
466+
return or_(zext(i32_ty, bitcast(a, i16_ty)),
467+
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
468+
};
469+
470+
SmallVector<Value> outVals32(outVals.size() / 2);
471+
for (int i = 0; i < outVals32.size(); ++i) {
472+
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
473+
}
474+
outVals = outVals32;
475+
}
476+
434477
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
435478
op.getType());
436479
rewriter.replaceOp(op, result);

0 commit comments

Comments
 (0)