Skip to content

Commit 9feee06

Browse files
victor-eds authored and whitneywhtsang committed
[TritonIntelGPUToLLVM] Adapt layout conversion to new LL interface
Replace `divideRight` calls with `quotient` calls and simplify code following upstream model. Signed-off-by: victor-eds <[email protected]>
1 parent def104e commit 9feee06

File tree

2 files changed

+73
-107
lines changed

2 files changed

+73
-107
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1752,7 +1752,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
17521752
// CHECK-NOT: llvm.store
17531753
// CHECK-NOT: llvm.load
17541754
// CHECK: llvm.insertvalue
1755-
// CHECK: llvm.extractvalue
17561755
tt.func public @convert_single_element() attributes {noinline = false} {
17571756
%cst = arith.constant dense<1.000000e+03> : tensor<1xf32, #blocked1>
17581757
%0 = triton_gpu.convert_layout %cst : tensor<1xf32, #blocked1> -> tensor<1xf32, #blocked>

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 73 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -455,9 +455,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
455455
StringAttr kBlock = str_attr("block");
456456

457457
LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
458-
std::optional<LinearLayout> conversion = comp.divideRight(
459-
LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
460-
LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
458+
std::optional<LinearLayout> conversion =
459+
comp.quotient(kBlock)->quotient(kWarp);
461460
assert(conversion && "Expecting valid conversion");
462461
// Expected conversion is:
463462
// - register=1 -> (0, 1)
@@ -516,85 +515,87 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
516515
const auto &shape = op.getType().getShape();
517516
auto srcTy = op.getSrc().getType();
518517
auto dstTy = op.getType();
519-
std::optional<LinearLayout> srcLayout =
520-
toLinearLayout(shape, srcTy.getEncoding());
521-
std::optional<LinearLayout> dstLayout =
522-
toLinearLayout(shape, dstTy.getEncoding());
523-
if (!srcLayout.has_value() || !dstLayout.has_value()) {
524-
return failure();
525-
}
526518

527-
// There are four cases to handle.
528-
//
529-
// 1. Transfer between values in the same thread, in which case we simply
530-
// reorder the elements of adaptor.getSrc().
531-
// 2. Transfer between values in the same warp, in which case we try to
532-
// move values using warp shuffles, though if the pattern is complicated
533-
// enough we may fall back to using shared memory (case 3).
534-
// 3. Transfer between values in the same CTA, in which case we move values
535-
// through shared memory.
536-
// 4. Transfer between values in different CTAs, in which case we move
537-
// values through distributed shared memory.
538-
//
539-
// We can tell which case we're in by examining `conversion`.
540-
// For example, if the block -> block mapping is an identity layout: {1, 2,
541-
// 4, ...}, then there's no movement between data in different CTAs, and we
542-
// know we're not in case 4.
543-
if (cvtReordersRegisters(srcTy, dstTy)) { // Case 1.
544-
return transferWithinThread(op, *srcLayout, *dstLayout, adaptor,
545-
rewriter);
519+
auto conversion = minimalCvtLayout(srcTy, dstTy);
520+
if (!conversion.has_value()) {
521+
return rewriter.notifyMatchFailure(
522+
op, "NYI. srcTy and/or dstTy don't implement LLs yet");
546523
}
524+
LinearLayout srcLayout =
525+
*toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
526+
LinearLayout dstLayout =
527+
*toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
547528

548-
if (cvtNeedsWarpShuffle(srcTy, dstTy)) { // Case 2.
549-
return transferWithinLane(op, *srcLayout, *dstLayout, adaptor, rewriter);
550-
}
529+
StringAttr kBlock = str_attr("block");
530+
StringAttr kWarp = str_attr("warp");
531+
StringAttr kLane = str_attr("lane");
532+
StringAttr kRegister = str_attr("register");
551533

552-
// TODO: match transferWithinBlockOrGroup from
553-
// TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
554-
return transferWithinBlockGroup(op, *srcLayout, *dstLayout, adaptor,
555-
rewriter);
534+
assert(to_vector(conversion->getInDimNames()) ==
535+
to_vector(conversion->getOutDimNames()));
536+
auto dims = conversion->getInDimNames();
537+
if (llvm::is_contained(dims, str_attr("block"))) {
538+
// Case 1: Transfer between values in different CTAs.
539+
// This requires moving values through distributed shared memory.
540+
return rewriter.notifyMatchFailure(
541+
op, "NYI: Transfer between different CTAs");
542+
} else if (llvm::is_contained(dims, str_attr("warp"))) {
543+
return rewriter.notifyMatchFailure(
544+
op, "NYI: Transfer between different warps");
545+
} else if (llvm::is_contained(dims, str_attr("lane"))) {
546+
// Case 3: Transfer between values in the same warp (sub-group), in which
547+
// case we move values through sub-group shuffles or through SLM.
548+
// If the operation is a supported sub-group shuffle, perform via shuffle
549+
// operations.
550+
if (isSubGroupShuffle(srcLayout, dstLayout) &&
551+
isSupportedSubGroupShuffle(op, adaptor)) {
552+
performSubGroupShuffle(op, srcLayout, dstLayout, adaptor, rewriter);
553+
return success();
554+
}
555+
// If the operation is a supported sub-group transposition, perform via
556+
// SLM.
557+
if (isSubGroupTranspose(srcLayout, dstLayout) &&
558+
isSupportedSubGroupTranspose(op, adaptor)) {
559+
performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter);
560+
return success();
561+
}
562+
// TODO(jlebar): Implement me.
563+
return failure();
564+
} else if (llvm::is_contained(dims, str_attr("register"))) {
565+
// Case 4. Transfer between values in the same thread, in which case we
566+
// simply reorder the elements of adaptor.getSrc().
567+
return transferWithinThread(
568+
op, dstLayout.getFreeVariableMasks()[kRegister],
569+
dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
570+
} else {
571+
// The two layouts are equivalent. We should probably remove these in
572+
// RemoveLayoutConversion.
573+
rewriter.replaceOp(op, adaptor.getSrc());
574+
return success();
575+
}
556576
}
557577

558578
LogicalResult
559-
transferWithinThread(ConvertLayoutOp op, const LinearLayout &srcLayout,
560-
const LinearLayout &dstLayout, OpAdaptor adaptor,
579+
transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
580+
const LinearLayout &conversion, OpAdaptor adaptor,
561581
ConversionPatternRewriter &rewriter) const {
562582
MLIRContext *ctx = op.getContext();
563583
auto loc = op.getLoc();
564584
StringAttr kRegister = str_attr("register");
565-
StringAttr kLane = str_attr("lane");
566-
StringAttr kWarp = str_attr("warp");
567-
StringAttr kBlock = str_attr("block");
568-
569-
// There are three possible cases:
570-
//
571-
// 1. `srcLayout` has the same number of registers as `dstLayout`.
572-
// 2. `srcLayout` has fewer registers than `dstLayout`.
573-
// 3. `srcLayout` has more registers than `dstLayout`.
574-
//
575-
// In the second case `srcLayout . dstLayout^-1` is not surjective
576-
// because not all destination registers are covered.
577-
// Since the goal is to cover all of the destination
578-
// registers, we can instead use `dstLayout . srcLayout^-1`.
579-
LinearLayout conversion = dstLayout.invertAndCompose(srcLayout);
580-
auto dstToSrc = conversion.divideRight(
581-
LinearLayout::identity1D(conversion.getInDimSize(kLane), kLane, kLane) *
582-
LinearLayout::identity1D(conversion.getInDimSize(kWarp), kWarp, kWarp) *
583-
LinearLayout::identity1D(conversion.getInDimSize(kBlock), kBlock,
584-
kBlock));
585-
586585
assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
587-
assert(ArrayRef(to_vector(dstToSrc->getInDimNames())) ==
588-
ArrayRef{kRegister});
589-
assert(ArrayRef(to_vector(dstToSrc->getOutDimNames())) ==
590-
ArrayRef{kRegister});
591586

592587
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
593-
SmallVector<Value> outVals;
594-
outVals.resize(dstToSrc->getInDimSize(kRegister));
595-
for (int i = 0; i < dstToSrc->getInDimSize(kRegister); i++) {
596-
auto srcIdx = dstToSrc->apply({{kRegister, i}});
597-
outVals[i] = inVals[srcIdx.begin()->second];
588+
SmallVector<Value> outVals(numRegs);
589+
for (int i = 0; i < outVals.size(); i++) {
590+
// Remove free masks from the register index
591+
// For example, if idx = 0b00111, and masks = 0b00100, then we get
592+
// 0b00011. It means that register 7 (0b111) has the same value as
593+
// register 3 (0b011).
594+
auto idx = i & (~regMasks);
595+
auto srcIdx = conversion.hasInDim(kRegister)
596+
? conversion.apply({{kRegister, idx}}).begin()->second
597+
: idx;
598+
outVals[i] = inVals[srcIdx];
598599
}
599600
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
600601
op.getType());
@@ -611,9 +612,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
611612
StringAttr kBlock = str_attr("block");
612613

613614
LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
614-
std::optional<LinearLayout> conversion = comp.divideRight(
615-
LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
616-
LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
615+
std::optional<LinearLayout> conversion =
616+
comp.quotient(kBlock)->quotient(kWarp);
617617
assert(conversion && "Expecting valid conversion");
618618
// TODO: Support more kinds of shuffles.
619619
// Expected conversion is:
@@ -667,11 +667,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
667667
StringAttr kWarp = str_attr("warp");
668668
StringAttr kBlock = str_attr("block");
669669
LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
670-
std::optional<LinearLayout> conversion = comp.divideRight(
671-
LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
672-
LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
673-
assert(conversion && "Expecting valid layout");
674-
int32_t subGroupSize = conversion->getOutDimSize(kLane);
670+
LinearLayout conversion = *comp.quotient(kBlock)->quotient(kWarp);
671+
int32_t subGroupSize = conversion.getOutDimSize(kLane);
675672

676673
Location loc = op.getLoc();
677674

@@ -772,28 +769,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
772769
.Default(false);
773770
}
774771

775-
LogicalResult transferWithinLane(ConvertLayoutOp op,
776-
const LinearLayout &srcLayout,
777-
const LinearLayout &dstLayout,
778-
OpAdaptor adaptor,
779-
ConversionPatternRewriter &rewriter) const {
780-
// If the operation is a supported sub-group shuffle, perform via shuffle
781-
// operations.
782-
if (isSubGroupShuffle(srcLayout, dstLayout) &&
783-
isSupportedSubGroupShuffle(op, adaptor)) {
784-
performSubGroupShuffle(op, srcLayout, dstLayout, adaptor, rewriter);
785-
return success();
786-
}
787-
// If the operation is a supported sub-group transposition, perform via SLM.
788-
if (isSubGroupTranspose(srcLayout, dstLayout) &&
789-
isSupportedSubGroupTranspose(op, adaptor)) {
790-
performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter);
791-
return success();
792-
}
793-
// TODO(jlebar): Implement me.
794-
return failure();
795-
}
796-
797772
bool isValidTypeForSubGroupTranspose(Type type) const {
798773
return TypeSwitch<Type, bool>(type)
799774
.Case([](IntegerType intTy) {
@@ -967,14 +942,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
967942
}
968943
return unwrapFromVectors(loc, transposedVecs, rewriter);
969944
}
970-
971-
LogicalResult
972-
transferWithinBlockGroup(ConvertLayoutOp op, const LinearLayout &srcLayout,
973-
const LinearLayout &dstLayout, OpAdaptor adaptor,
974-
ConversionPatternRewriter &rewriter) const {
975-
// TODO(jlebar): Implement me.
976-
return failure();
977-
}
978945
};
979946

980947
} // namespace

0 commit comments

Comments
 (0)