Skip to content

Commit 986d273

Browse files
committed
Check transposition feasibility before performing it
1 parent d386951 commit 986d273

File tree

1 file changed

+77
-55
lines changed

1 file changed

+77
-55
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ struct ConvertLayoutOpConversion
434434

435435
struct ConvertLayoutOpUsingLinearLayoutsConversion
436436
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
437-
constexpr static unsigned maxSubGroupTransposeWidth = 64;
437+
constexpr static unsigned minSubGroupTransposeWidth = 8;
438438

439439
// Set benefit to 2 so that this pattern applies before other convert-layout
440440
// conversions. TODO(jlebar): Eventually we want this to be the only pattern.
@@ -557,14 +557,44 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
557557
return success();
558558
}
559559

560+
bool isSupportedSubGroupTranspose(ConvertLayoutOp op,
561+
OpAdaptor adaptor) const {
562+
auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
563+
ArrayRef<Type> body = srcType.getBody();
564+
auto mod = op->getParentOfType<ModuleOp>();
565+
// Only supporting sub_group_size^2 transpositions for now.
566+
if (body.size() !=
567+
mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod))
568+
return false;
569+
return TypeSwitch<Type, bool>(body.front())
570+
.Case([this](FloatType floatTy) {
571+
// Support via bitcasting to integer type.
572+
return isValidTypeForSubGroupTranspose(
573+
IntegerType::get(floatTy.getContext(), floatTy.getWidth()));
574+
})
575+
.Case([this](IntegerType intTy) {
576+
// Support via extending to supported type.
577+
return isValidTypeForSubGroupTranspose(intTy) ||
578+
intTy.getWidth() < minSubGroupTransposeWidth;
579+
})
580+
.Case([](LLVM::LLVMPointerType) {
581+
// Support via ptrtoint
582+
return true;
583+
})
584+
.Default([](auto) { return false; });
585+
}
586+
560587
LogicalResult transferWithinLane(ConvertLayoutOp op,
561588
const LinearLayout &srcLayout,
562589
const LinearLayout &dstLayout,
563590
OpAdaptor adaptor,
564591
ConversionPatternRewriter &rewriter) const {
565-
if (isSubGroupTranspose(srcLayout, dstLayout))
566-
return performSubGroupTranspose(op, srcLayout, dstLayout, adaptor,
567-
rewriter);
592+
// If the operation is a supported sub-group transposition, perform via SLM.
593+
if (isSubGroupTranspose(srcLayout, dstLayout) &&
594+
isSupportedSubGroupTranspose(op, adaptor)) {
595+
performSubGroupTranspose(op, srcLayout, dstLayout, adaptor, rewriter);
596+
return success();
597+
}
568598
return failure();
569599
}
570600

@@ -577,61 +607,56 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
577607
.Default([](auto) { return false; });
578608
}
579609

580-
LogicalResult
581-
performSubGroupTranspose(ConvertLayoutOp op, const LinearLayout &srcLayout,
582-
const LinearLayout &dstLayout, OpAdaptor adaptor,
583-
ConversionPatternRewriter &rewriter) const {
610+
void performSubGroupTranspose(ConvertLayoutOp op,
611+
const LinearLayout &srcLayout,
612+
const LinearLayout &dstLayout,
613+
OpAdaptor adaptor,
614+
ConversionPatternRewriter &rewriter) const {
615+
assert(isSubGroupTranspose(srcLayout, dstLayout) &&
616+
"Expecting sub-group transpose");
617+
assert(isSupportedSubGroupTranspose(op, adaptor) &&
618+
"Expecting supported sub-group transpose");
619+
584620
Location loc = op.getLoc();
585621

586622
SmallVector<Value> inVals =
587623
unpackLLElements(loc, adaptor.getSrc(), rewriter);
588624

589625
// TODO: Support multiples of sub_group_size
590626
auto mod = op->getParentOfType<ModuleOp>();
591-
if (inVals.size() !=
592-
mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod))
593-
return failure();
594627

595628
auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
596-
Type origElemTy = srcTy.getElementType();
597-
598-
LogicalResult conversionRes =
599-
TypeSwitch<Type, LogicalResult>(origElemTy)
600-
.Case([&](FloatType floatTy) {
601-
// TODO: Support FP4.
602-
Type dstType = int_ty(floatTy.getWidth());
603-
if (!isValidTypeForSubGroupTranspose(dstType))
604-
return failure();
605-
llvm::transform(
606-
inVals, std::begin(inVals),
607-
[&](Value val) -> Value { return bitcast(val, dstType); });
608-
return success();
609-
})
610-
.Case([&](IntegerType intTy) {
611-
if (isValidTypeForSubGroupTranspose(intTy))
612-
return success();
613-
if (intTy.getWidth() > maxSubGroupTransposeWidth)
614-
return failure();
615-
// intTy.getWidth() < minSubGroupTransposeWidth
616-
Type dstType = i8_ty;
617-
llvm::transform(
618-
inVals, std::begin(inVals),
619-
[&](Value val) -> Value { return zext(dstType, val); });
620-
return success();
621-
})
622-
.Case([&](triton::PointerType) {
623-
Type dstType = i64_ty;
624-
assert(isValidTypeForSubGroupTranspose(dstType) &&
625-
"i64 type should be supported");
626-
llvm::transform(
627-
inVals, std::begin(inVals),
628-
[&](Value val) -> Value { return ptrtoint(dstType, val); });
629-
return success();
630-
})
631-
.Default([&](auto) { return failure(); });
632-
633-
if (failed(conversionRes))
634-
return conversionRes;
629+
Type origElemTy = inVals.front().getType();
630+
631+
TypeSwitch<Type>(origElemTy)
632+
.Case([&](FloatType floatTy) {
633+
// TODO: Support FP4.
634+
Type dstType = int_ty(floatTy.getWidth());
635+
assert(isValidTypeForSubGroupTranspose(dstType) &&
636+
"Expecting valid type");
637+
llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
638+
return bitcast(val, dstType);
639+
});
640+
})
641+
.Case([&](IntegerType intTy) {
642+
if (isValidTypeForSubGroupTranspose(intTy))
643+
return;
644+
assert(intTy.getWidth() < minSubGroupTransposeWidth &&
645+
"Expecting type to extend to i8");
646+
Type dstType = i8_ty;
647+
llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
648+
return zext(dstType, val);
649+
});
650+
})
651+
.Case([&](LLVM::LLVMPointerType) {
652+
Type dstType = i64_ty;
653+
assert(isValidTypeForSubGroupTranspose(dstType) &&
654+
"i64 type should be supported");
655+
llvm::transform(inVals, std::begin(inVals), [&](Value val) -> Value {
656+
return ptrtoint(dstType, val);
657+
});
658+
})
659+
.Default([](auto) { llvm_unreachable("Unsupported type"); });
635660

636661
SmallVector<Value> outVals =
637662
performSubGroupTranspose(loc, inVals, rewriter);
@@ -650,18 +675,15 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
650675
outVals, std::begin(outVals),
651676
[&](Value val) -> Value { return trunc(origElemTy, val); });
652677
})
653-
.Case([&](triton::PointerType ptrTy) {
654-
Type llvmPtrTy = getTypeConverter()->convertType(ptrTy);
655-
assert(llvmPtrTy && "Type conversion failed");
678+
.Case([&](LLVM::LLVMPointerType ptrTy) {
656679
llvm::transform(
657680
outVals, std::begin(outVals),
658-
[&](Value val) -> Value { return inttoptr(llvmPtrTy, val); });
681+
[&](Value val) -> Value { return inttoptr(ptrTy, val); });
659682
});
660683

661684
Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
662685
op.getType());
663686
rewriter.replaceOp(op, result);
664-
return success();
665687
}
666688

667689
VectorType

0 commit comments

Comments
 (0)