3838#include " llvm/ADT/TypeSwitch.h"
3939#include " llvm/ADT/bit.h"
4040#include " llvm/Support/Casting.h"
41+ #include " llvm/Support/Format.h"
42+ #include " llvm/Support/FormatVariadic.h"
4143#include " llvm/Support/LogicalResult.h"
4244#include " llvm/Support/raw_ostream.h"
4345
@@ -63,6 +65,8 @@ constexpr unsigned packedSizeInBitsForDefault =
6365 16 ; // Minimum packing size per register for DPAS A.
6466constexpr unsigned packedSizeInBitsForDpasB =
6567 32 ; // Minimum packing size per register for DPAS B.
/// Attribute-name prefixes used to temporarily tag xegpu layout information on
/// op operands ("layout_operand_<n>") and results ("layout_result_<n>") while
/// distribution runs; filterTemporaryLayoutAttributes strips them again before
/// ops are finalized.
static const char *const operandLayoutNamePrefix = "layout_operand_";
static const char *const resultLayoutNamePrefix = "layout_result_";
6670
6771namespace {
6872
@@ -686,7 +690,8 @@ void attachLayoutAttributeToUsers(Value v, xegpu::LayoutAttr layout) {
686690 continue ;
687691 }
688692 // / For every other user, use a generic attribute name.
689- std::string attrName = " op" + std::to_string (operandNumber);
693+ std::string attrName =
694+ operandLayoutNamePrefix + std::to_string (operandNumber);
690695 owner->setAttr (attrName, layout);
691696 }
692697}
@@ -746,7 +751,7 @@ static LogicalResult attachLayoutAttributes(
746751 for (auto [i, r] : llvm::enumerate (op->getResults ())) {
747752 auto layoutInfo = getLayoutInfoForResult (r);
748753 if (layoutInfo) {
749- auto attrName = " r " + std::to_string (i);
754+ auto attrName = resultLayoutNamePrefix + std::to_string (i);
750755 op->setAttr (attrName, layoutInfo);
751756 // / Attach the layout attribute to the users of the result.
752757 attachLayoutAttributeToUsers (r, layoutInfo);
@@ -819,16 +824,29 @@ static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
819824 return distVecTyOrFailure.value ();
820825}
821826
822- static Value reconcileDistribtedVecType (Value orig, VectorType expected,
823- PatternRewriter &rewriter) {
827+ static Value reshapeDistributedVecType (Value orig, VectorType expected,
828+ PatternRewriter &rewriter) {
824829 assert (isa<VectorType>(orig.getType ()) && " expecting vector type" );
825830 auto origVecType = cast<VectorType>(orig.getType ());
826831 // / No need to reconcile if the types are the same.
827832 if (origVecType == expected)
828833 return orig;
829- auto castOp = rewriter.create <UnrealizedConversionCastOp>(orig.getLoc (),
830- expected, orig);
831- return castOp.getResult (0 );
834+ auto castOp =
835+ rewriter.create <vector::ShapeCastOp>(orig.getLoc (), expected, orig);
836+ return castOp.getResult ();
837+ }
838+
839+ static SmallVector<NamedAttribute>
840+ filterTemporaryLayoutAttributes (ArrayRef<NamedAttribute> attrs) {
841+ SmallVector<NamedAttribute> newAttrs;
842+ for (auto attr : attrs) {
843+ if (attr.getName ().strref ().contains (operandLayoutNamePrefix) ||
844+ attr.getName ().strref ().contains (resultLayoutNamePrefix)) {
845+ continue ;
846+ }
847+ newAttrs.push_back (attr);
848+ }
849+ return newAttrs;
832850}
833851
834852// / Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
@@ -903,11 +921,11 @@ struct MoveFuncBodyToWarpExecuteOnLane0
903921};
904922
905923// / Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
906- // / `gpu.warp_execute_on_lane_0` and put it after the warp op. The warp op will
907- // / still contain the original op that will not be used by the yield op (and
908- // / should be cleaned up later with dce). The yield op will bypass the
909- // / create_nd_tdesc's arguments. Tensor descriptor is not distributed because it
910- // / is a uniform value accorss all work items within the subgroup.
924+ // / `gpu.warp_execute_on_lane_0` and put it after the warp op. The warp op
925+ // / will still contain the original op that will not be used by the yield op
926+ // / (and should be cleaned up later with dce). The yield op will bypass the
927+ // / create_nd_tdesc's arguments. Tensor descriptor is not distributed because
928+ // / it is a uniform value across all work items within the subgroup.
911929// /
912930// / Example:
913931// /
@@ -985,10 +1003,10 @@ struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
9851003 }
9861004};
9871005
988- // / Sink a store_nd op at the end of enclosing `gpu.warp_execute_on_lane_0`. In
989- // / case arguments for the store are passed through the warp op interface they
990- // / would be propagated as returned values. Only the source vector for the store
991- // / is distributed according to sg_map attribute.
1006+ // / Sink a store_nd op at the end of enclosing `gpu.warp_execute_on_lane_0`.
1007+ // / In case arguments for the store are passed through the warp op interface
1008+ // / they would be propagated as returned values. Only the source vector for
1009+ // / the store is distributed according to sg_map attribute.
9921010// /
9931011// / Example:
9941012// /
@@ -1033,7 +1051,6 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10331051 " Failed to distribute the type" );
10341052 VectorType distributedTypeByWarpOp =
10351053 distributedTypeByWarpOpOrFailure.value ();
1036- llvm::errs () << " distributed type: " << distributedTypeByWarpOp << " \n " ;
10371054
10381055 SmallVector<size_t > newRetIndices;
10391056 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1050,21 +1067,21 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10501067
10511068 // / For the value operand, there can be a conflict between the vector type
10521069 // / distributed by the warp op and (xegpu-specific) distributed type
1053- // / supported by the store op. We reconcile these mismatches by inserting a
1054- // / cast. These gets cancelled out later.
1070+ // / supported by the store op. We reconcile these mismatches by inserting
1071+ // / a cast. These get cancelled out later.
10551072 auto storeNdDistributedValueTyOrFailure =
10561073 storeOp.getTensorDescType ().getDistributedVectorType ();
10571074 if (failed (storeNdDistributedValueTyOrFailure))
10581075 return rewriter.notifyMatchFailure (
10591076 storeOp, " Failed to get distributed vector type for the store op" );
1060- newStoreOperands.push_back (reconcileDistribtedVecType (
1077+ newStoreOperands.push_back (reshapeDistributedVecType (
10611078 newWarpOp.getResult (newRetIndices[0 ]),
10621079 storeNdDistributedValueTyOrFailure.value (), rewriter));
10631080 newStoreOperands.push_back (newWarpOp.getResult (newRetIndices[1 ]));
10641081
1065- rewriter.create <xegpu::StoreNdOp>(newWarpOp. getLoc (), TypeRange{},
1066- newStoreOperands);
1067- storeOp-> setDialectAttrs (storeOp->getDialectAttrs ( ));
1082+ rewriter.create <xegpu::StoreNdOp>(
1083+ newWarpOp. getLoc (), TypeRange{}, newStoreOperands,
1084+ filterTemporaryLayoutAttributes (storeOp->getAttrs () ));
10681085 rewriter.eraseOp (storeOp);
10691086 return success ();
10701087 }
@@ -1074,8 +1091,9 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10741091// / `gpu.warp_execute_on_lane_0` and put it after the warp op.
10751092// / The warp op will still contain the original op that will not be used by
10761093// / the yield op (and should be cleaned up later with dce). The yield op will
1077- // / bypass the load's arguments. Only the loaded vector is distributed according
1078- // / to sg_map attribute and, tensor descriptor types is not distributed.
1094+ // / bypass the load's arguments. Only the loaded vector is distributed
1095+ // / according to sg_map attribute and the tensor descriptor type is not
1096+ // / distributed.
10791097// /
10801098// / Example:
10811099// /
@@ -1122,7 +1140,8 @@ struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
11221140
11231141 SmallVector<size_t > newRetIndices;
11241142 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
1125- rewriter, subgroupOp, /* new yielded values = */ loadOp.getTensorDesc (),
1143+ rewriter, subgroupOp,
1144+ /* new yielded values = */ loadOp.getTensorDesc (),
11261145 /* new yielded types = */ tensorDescTy, newRetIndices);
11271146
11281147 // / Create a new load op outside the warp op with the distributed vector
@@ -1135,13 +1154,14 @@ struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
11351154 loadOp, " Failed to get distributed vector type for the load op" );
11361155 Value newLoadOp = rewriter.create <xegpu::LoadNdOp>(
11371156 newWarpOp.getLoc (), loadNdDistValueTyOrFailure.value (),
1138- newWarpOp->getResult (newRetIndices[0 ]), loadOp->getAttrs ());
1157+ newWarpOp->getResult (newRetIndices[0 ]),
1158+ filterTemporaryLayoutAttributes (loadOp->getAttrs ()));
11391159 Value distributedVal = newWarpOp.getResult (operandIdx);
1140- // / There can be a conflict between the vector type distributed by the warp
1141- // / op and (xegpu-specific) distributed type supported by the load op. We
1142- // / reconcile these mismatches by inserting a cast.
1143- newLoadOp = reconcileDistribtedVecType (newLoadOp, distributedTypeByWarpOp,
1144- rewriter);
1160+ // / There can be a conflict between the vector type distributed by the
1161+ // / warp op and (xegpu-specific) distributed type supported by the load
1162+ // / op. We reconcile these mismatches by inserting a cast.
1163+ newLoadOp =
1164+ reshapeDistributedVecType (newLoadOp, distributedTypeByWarpOp, rewriter);
11451165 rewriter.replaceAllUsesWith (distributedVal, newLoadOp);
11461166 return success ();
11471167 }
@@ -1161,8 +1181,9 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
11611181 unsigned operandIdx = operand->getOperandNumber ();
11621182 xegpu::LayoutAttr layoutA = dpasOp.getALayoutAttr ();
11631183 xegpu::LayoutAttr layoutB = dpasOp.getBLayoutAttr ();
1184+ auto layoutCName = llvm::formatv (" {0}{1}" , resultLayoutNamePrefix, 0 ).str ();
11641185 xegpu::LayoutAttr layoutOut =
1165- dpasOp->getAttrOfType <xegpu::LayoutAttr>(" r0 " );
1186+ dpasOp->getAttrOfType <xegpu::LayoutAttr>(layoutCName );
11661187 if (!layoutA || !layoutB || !layoutOut)
11671188 return rewriter.notifyMatchFailure (
11681189 dpasOp,
@@ -1211,7 +1232,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
12111232 }
12121233
12131234 for (auto i : newRetIndices) {
1214- newDpasOperands.push_back (reconcileDistribtedVecType (
1235+ newDpasOperands.push_back (reshapeDistributedVecType (
12151236 newWarpOp.getResult (i),
12161237 newDpasOperandExpectedTypes[newDpasOperands.size ()], rewriter));
12171238 }
@@ -1220,7 +1241,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
12201241 newDpasOperands, dpasOp->getAttrs ());
12211242 Value disributedVal = newWarpOp.getResult (operandIdx);
12221243 // / Reconcile the output type.
1223- disributedVal = reconcileDistribtedVecType (
1244+ disributedVal = reshapeDistributedVecType (
12241245 disributedVal,
12251246 getDistributedVectorType (layoutOut, dpasOp.getResultType ()), rewriter);
12261247 rewriter.replaceAllUsesWith (disributedVal, newDpasOp);
0 commit comments