@@ -434,7 +434,7 @@ struct ConvertLayoutOpConversion
434434
435435struct ConvertLayoutOpUsingLinearLayoutsConversion
436436 : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
437- constexpr static unsigned maxSubGroupTransposeWidth = 64 ;
437+ constexpr static unsigned minSubGroupTransposeWidth = 8 ;
438438
439439 // Set benefit to 2 so that this pattern applies before other convert-layout
440440 // conversions. TODO(jlebar): Eventually we want this to be the only pattern.
@@ -557,14 +557,44 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
557557 return success ();
558558 }
559559
// Predicate added by this change: reports whether the converted source
// operand of `op` is in a form performSubGroupTranspose can handle.
// transferWithinLane uses it as a guard and performSubGroupTranspose asserts
// it, so the accepted cases here mirror the conversions done there.
//
// `op`      - the convert-layout op; only used to reach the enclosing ModuleOp.
// `adaptor` - supplies the already-converted source value, whose type is cast
//             to an LLVM struct type.
// Returns true if the transpose is supported, false otherwise.
560+ bool isSupportedSubGroupTranspose (ConvertLayoutOp op,
561+ OpAdaptor adaptor) const {
562+ auto srcType = cast<LLVM::LLVMStructType>(adaptor.getSrc ().getType ());
563+ ArrayRef<Type> body = srcType.getBody ();
564+ auto mod = op->getParentOfType <ModuleOp>();
565+ // Only supporting sub_group_size^2 transpositions for now.
// The number of struct members must equal the module's threads-per-warp
// value; any other count is rejected.
566+ if (body.size () !=
567+ mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp (mod))
568+ return false ;
// Only the first struct member's type is inspected.
// NOTE(review): presumably all members share one element type - confirm.
569+ return TypeSwitch<Type, bool >(body.front ())
570+ .Case ([this ](FloatType floatTy) {
571+ // Support via bitcasting to integer type.
// Accepted only if the integer type of the same bit width is itself valid.
572+ return isValidTypeForSubGroupTranspose (
573+ IntegerType::get (floatTy.getContext (), floatTy.getWidth ()));
574+ })
575+ .Case ([this ](IntegerType intTy) {
576+ // Support via extending to supported type.
// Integers narrower than minSubGroupTransposeWidth are accepted even when
// they are not valid as-is.
577+ return isValidTypeForSubGroupTranspose (intTy) ||
578+ intTy.getWidth () < minSubGroupTransposeWidth;
579+ })
580+ .Case ([](LLVM::LLVMPointerType) {
581+ // Support via ptrtoint
// LLVM pointer types are accepted unconditionally.
582+ return true ;
583+ })
// Any other element type is unsupported.
584+ .Default ([](auto ) { return false ; });
585+ }
586+
// Attempts to handle `op` as a sub-group transpose.
//
// Removed ('-') version: checked only isSubGroupTranspose and returned the
// LogicalResult produced by performSubGroupTranspose.
// Added ('+') version: requires both isSubGroupTranspose(srcLayout, dstLayout)
// and isSupportedSubGroupTranspose(op, adaptor) before calling
// performSubGroupTranspose, then returns success() itself.
// performSubGroupTranspose now returns void and asserts both conditions, so
// they have to be checked here first.
//
// Returns success() when the transpose was performed, failure() otherwise.
560587 LogicalResult transferWithinLane (ConvertLayoutOp op,
561588 const LinearLayout &srcLayout,
562589 const LinearLayout &dstLayout,
563590 OpAdaptor adaptor,
564591 ConversionPatternRewriter &rewriter) const {
565- if (isSubGroupTranspose (srcLayout, dstLayout))
566- return performSubGroupTranspose (op, srcLayout, dstLayout, adaptor,
567- rewriter);
592+ // If the operation is a supported sub-group transposition, perform via SLM.
593+ if (isSubGroupTranspose (srcLayout, dstLayout) &&
594+ isSupportedSubGroupTranspose (op, adaptor)) {
595+ performSubGroupTranspose (op, srcLayout, dstLayout, adaptor, rewriter);
596+ return success ();
597+ }
// Neither a sub-group transpose nor a supported one: report failure.
568598 return failure ();
569599 }
570600
@@ -577,61 +607,56 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
577607 .Default ([](auto ) { return false ; });
578608 }
579609
580- LogicalResult
581- performSubGroupTranspose (ConvertLayoutOp op, const LinearLayout &srcLayout,
582- const LinearLayout &dstLayout, OpAdaptor adaptor,
583- ConversionPatternRewriter &rewriter) const {
610+ void performSubGroupTranspose (ConvertLayoutOp op,
611+ const LinearLayout &srcLayout,
612+ const LinearLayout &dstLayout,
613+ OpAdaptor adaptor,
614+ ConversionPatternRewriter &rewriter) const {
615+ assert (isSubGroupTranspose (srcLayout, dstLayout) &&
616+ " Expecting sub-group transpose" );
617+ assert (isSupportedSubGroupTranspose (op, adaptor) &&
618+ " Expecting supported sub-group transpose" );
619+
584620 Location loc = op.getLoc ();
585621
586622 SmallVector<Value> inVals =
587623 unpackLLElements (loc, adaptor.getSrc (), rewriter);
588624
589625 // TODO: Support multiples of sub_group_size
590626 auto mod = op->getParentOfType <ModuleOp>();
591- if (inVals.size () !=
592- mlir::triton::gpu::TritonGPUDialect::getThreadsPerWarp (mod))
593- return failure ();
594627
595628 auto srcTy = cast<RankedTensorType>(op.getSrc ().getType ());
596- Type origElemTy = srcTy.getElementType ();
597-
598- LogicalResult conversionRes =
599- TypeSwitch<Type, LogicalResult>(origElemTy)
600- .Case ([&](FloatType floatTy) {
601- // TODO: Support FP4.
602- Type dstType = int_ty (floatTy.getWidth ());
603- if (!isValidTypeForSubGroupTranspose (dstType))
604- return failure ();
605- llvm::transform (
606- inVals, std::begin (inVals),
607- [&](Value val) -> Value { return bitcast (val, dstType); });
608- return success ();
609- })
610- .Case ([&](IntegerType intTy) {
611- if (isValidTypeForSubGroupTranspose (intTy))
612- return success ();
613- if (intTy.getWidth () > maxSubGroupTransposeWidth)
614- return failure ();
615- // intTy.getWidth() < minSubGroupTransposeWidth
616- Type dstType = i8_ty;
617- llvm::transform (
618- inVals, std::begin (inVals),
619- [&](Value val) -> Value { return zext (dstType, val); });
620- return success ();
621- })
622- .Case ([&](triton::PointerType) {
623- Type dstType = i64_ty;
624- assert (isValidTypeForSubGroupTranspose (dstType) &&
625- " i64 type should be supported" );
626- llvm::transform (
627- inVals, std::begin (inVals),
628- [&](Value val) -> Value { return ptrtoint (dstType, val); });
629- return success ();
630- })
631- .Default ([&](auto ) { return failure (); });
632-
633- if (failed (conversionRes))
634- return conversionRes;
629+ Type origElemTy = inVals.front ().getType ();
630+
631+ TypeSwitch<Type>(origElemTy)
632+ .Case ([&](FloatType floatTy) {
633+ // TODO: Support FP4.
634+ Type dstType = int_ty (floatTy.getWidth ());
635+ assert (isValidTypeForSubGroupTranspose (dstType) &&
636+ " Expecting valid type" );
637+ llvm::transform (inVals, std::begin (inVals), [&](Value val) -> Value {
638+ return bitcast (val, dstType);
639+ });
640+ })
641+ .Case ([&](IntegerType intTy) {
642+ if (isValidTypeForSubGroupTranspose (intTy))
643+ return ;
644+ assert (intTy.getWidth () < minSubGroupTransposeWidth &&
645+ " Expecting type to extend to i8" );
646+ Type dstType = i8_ty;
647+ llvm::transform (inVals, std::begin (inVals), [&](Value val) -> Value {
648+ return zext (dstType, val);
649+ });
650+ })
651+ .Case ([&](LLVM::LLVMPointerType) {
652+ Type dstType = i64_ty;
653+ assert (isValidTypeForSubGroupTranspose (dstType) &&
654+ " i64 type should be supported" );
655+ llvm::transform (inVals, std::begin (inVals), [&](Value val) -> Value {
656+ return ptrtoint (dstType, val);
657+ });
658+ })
659+ .Default ([](auto ) { llvm_unreachable (" Unsupported type" ); });
635660
636661 SmallVector<Value> outVals =
637662 performSubGroupTranspose (loc, inVals, rewriter);
@@ -650,18 +675,15 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
650675 outVals, std::begin (outVals),
651676 [&](Value val) -> Value { return trunc (origElemTy, val); });
652677 })
653- .Case ([&](triton::PointerType ptrTy) {
654- Type llvmPtrTy = getTypeConverter ()->convertType (ptrTy);
655- assert (llvmPtrTy && " Type conversion failed" );
678+ .Case ([&](LLVM::LLVMPointerType ptrTy) {
656679 llvm::transform (
657680 outVals, std::begin (outVals),
658- [&](Value val) -> Value { return inttoptr (llvmPtrTy , val); });
681+ [&](Value val) -> Value { return inttoptr (ptrTy , val); });
659682 });
660683
661684 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
662685 op.getType ());
663686 rewriter.replaceOp (op, result);
664- return success ();
665687 }
666688
667689 VectorType
0 commit comments