 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -679,17 +680,7 @@ void attachLayoutAttributeToUsers(Value v, xegpu::LayoutAttr layout) {
   for (OpOperand &user : v.getUses()) {
     Operation *owner = user.getOwner();
     unsigned operandNumber = user.getOperandNumber();
-    /// If the user is a DpasOp, set A, B or C layout attributes.
-    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
-      if (operandNumber == 0)
-        dpasOp.setALayoutAttr(layout);
-      else if (operandNumber == 1)
-        dpasOp.setBLayoutAttr(layout);
-      else if (operandNumber == 2)
-        dpasOp.setCLayoutAttr(layout);
-      continue;
-    }
-    /// For every other user, use a generic attribute name.
+    /// Use a generic name for ease of querying the layout attribute later.
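+    /// The name is the operand number appended to operandLayoutNamePrefix,
+    /// e.g., operand 2 of an op gets "<operandLayoutNamePrefix>2".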
     std::string attrName =
         operandLayoutNamePrefix + std::to_string(operandNumber);
     owner->setAttr(attrName, layout);
@@ -824,18 +815,66 @@ static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
   return distVecTyOrFailure.value();
 }
 
-static Value reshapeDistributedVecType(Value orig, VectorType expected,
-                                       PatternRewriter &rewriter) {
-  assert(isa<VectorType>(orig.getType()) && "expecting vector type");
-  auto origVecType = cast<VectorType>(orig.getType());
-  /// No need to reconcile if the types are the same.
-  if (origVecType == expected)
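+/// Helper to drop the layout attribute from a tensor descriptor type; shape,
+/// element type and encoding are preserved. Informally,
+/// !xegpu.tensor_desc<4x8xf32, #layout> becomes !xegpu.tensor_desc<4x8xf32>.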
+static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
+  return xegpu::TensorDescType::get(
+      tensorDesc.getContext(), tensorDesc.getShape(),
+      tensorDesc.getElementType(), tensorDesc.getEncoding(),
+      xegpu::LayoutAttr());
+}
+
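+/// Helper to resolve a type mismatch between a value produced by distribution
+/// and the type an op expects. Vector types are reconciled with a
+/// vector.shape_cast op; tensor descriptor types with a
+/// builtin.unrealized_conversion_cast op. Typical use in the patterns below
+/// (names illustrative):
+///   Value v = resolveDistributedTy(newWarpOp.getResult(idx), expectedTy,
+///                                  rewriter);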
+template <typename T>
+static Value resolveDistributedTy(Value orig, T expected,
+                                  PatternRewriter &rewriter) {
+  /// If orig and expected types are the same, return orig.
+  if (orig.getType() == expected)
     return orig;
-  auto castOp =
-      rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
-  return castOp.getResult();
+  /// If orig is a vector type, create a shape cast op to reconcile the types.
+  if (isa<VectorType>(orig.getType())) {
+    auto castOp =
+        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
+    return castOp.getResult();
+  }
+  /// If orig is a tensor descriptor type, create an unrealized conversion
+  /// cast op to reconcile the types.
+  if (isa<xegpu::TensorDescType>(orig.getType())) {
+    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
+                                                              expected, orig);
+    return castOp.getResult(0);
+  }
+  llvm_unreachable("Unsupported type for reconciliation");
+  return orig;
 }
 
 static SmallVector<NamedAttribute>
 filterTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
   SmallVector<NamedAttribute> newAttrs;
@@ -951,7 +990,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 ///        -> !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
-struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
+struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -993,8 +1032,11 @@ struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
       newDescOperands.push_back(newWarpOp.getResult(i));
     }
     rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedTensorDescTy =
+        dropLayouts(descOp.getType()); /// Distributed tensor descriptor type
+                                       /// does not contain layout info.
     auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
-        newWarpOp.getLoc(), descOp.getType(), newDescOperands,
+        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());
 
     Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -1027,7 +1069,7 @@ struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
 ///     !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
-struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
+struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -1065,19 +1107,24 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newStoreOperands;
 
-    /// For the value operand, there can be a conflict between the vector type
+    /// For the value operand, there can be a mismatch between the vector type
     /// distributed by the warp op and (xegpu-specific) distributed type
-    /// supported by the store op. We reconcile these mismatches by inserting
-    /// a cast. These gets cancelled out later.
+    /// supported by the store op. The type mismatch must be resolved using an
+    /// appropriate cast op.
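+    /// (For example, the warp op may yield a 2-D shape such as
+    /// vector<4x1xf32> while the store op expects the xegpu-specific
+    /// distributed shape; the shapes here are illustrative only.)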
     auto storeNdDistributedValueTyOrFailure =
         storeOp.getTensorDescType().getDistributedVectorType();
     if (failed(storeNdDistributedValueTyOrFailure))
       return rewriter.notifyMatchFailure(
           storeOp, "Failed to get distributed vector type for the store op");
-    newStoreOperands.push_back(reshapeDistributedVecType(
+    newStoreOperands.push_back(resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]),
         storeNdDistributedValueTyOrFailure.value(), rewriter));
-    newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[1]));
+    /// For the tensor descriptor operand, the layout attribute is dropped
+    /// after distribution. Types must be resolved in this case as well.
+    auto distributedTensorDescTy = dropLayouts(storeOp.getTensorDescType());
+    newStoreOperands.push_back(
+        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
+                             distributedTensorDescTy, rewriter));
 
     rewriter.create<xegpu::StoreNdOp>(
         newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
@@ -1117,7 +1164,7 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
 ///   %ld = xegpu.load_nd %r#0: !xegpu.tensor_desc<4x8xf32> -> vector<4x1xf32>
 ///
 /// ```
-struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
+struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -1161,13 +1208,13 @@ struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
     /// warp op and (xegpu-specific) distributed type supported by the load
     /// op. We reconcile these mismatches by inserting a cast.
     newLoadOp =
-        reshapeDistributedVecType(newLoadOp, distributedTypeByWarpOp, rewriter);
+        resolveDistributedTy(newLoadOp, distributedTypeByWarpOp, rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
     return success();
   }
 };
 
-struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
+struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -1179,15 +1226,21 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    xegpu::LayoutAttr layoutA = dpasOp.getALayoutAttr();
-    xegpu::LayoutAttr layoutB = dpasOp.getBLayoutAttr();
+    auto layoutAName =
+        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str();
+    auto layoutBName =
+        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 1).str();
     auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str();
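+    /// These names follow the generic scheme used when attaching layouts:
+    /// operand/result number appended to the prefix (A is operand 0, B is
+    /// operand 1, the result is result 0).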
+    xegpu::LayoutAttr layoutA =
+        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
+    xegpu::LayoutAttr layoutB =
+        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
     xegpu::LayoutAttr layoutOut =
         dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
     if (!layoutA || !layoutB || !layoutOut)
       return rewriter.notifyMatchFailure(
           dpasOp,
-          "the xegpu::Dpas op lacks sg_map attribute for A, B or output");
+          "the xegpu::Dpas op lacks layout attribute for A, B or output");
 
     auto distLhsTypeByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
@@ -1232,7 +1285,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
     }
 
     for (auto i : newRetIndices) {
-      newDpasOperands.push_back(reshapeDistributedVecType(
+      newDpasOperands.push_back(resolveDistributedTy(
           newWarpOp.getResult(i),
           newDpasOperandExpectedTypes[newDpasOperands.size()], rewriter));
     }
@@ -1241,7 +1294,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
         newDpasOperands, dpasOp->getAttrs());
     Value distributedVal = newWarpOp.getResult(operandIdx);
     /// Reconcile the output type.
-    distributedVal = reshapeDistributedVecType(
+    distributedVal = resolveDistributedTy(
         distributedVal,
         getDistributedVectorType(layoutOut, dpasOp.getResultType()), rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newDpasOp);
@@ -1266,8 +1319,8 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<SubgroupOpTensorDescOp, SubgroupOpStoreNd, SubgroupOpLoadNd,
-               SubgroupOpDpas>(patterns.getContext());
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution>(patterns.getContext());
 }
 
 void XeGPUSubgroupDistributePass::runOnOperation() {