@@ -696,9 +696,9 @@ class LayoutAttrAssignment {
696696 void assignToUsers (Value v, xegpu::LayoutAttr layout);
697697 xegpu::LayoutAttr getLayoutAttrForValue (Value v);
698698 LogicalResult resolveConflicts ();
699- function_ref<LayoutInfo(Value)>
700- getAnalysisResult; // Callable to get the layout of a value based on the
701- // layout propagation analysis.
699+ // Callable to get the layout of a value based on the layout propagation
700+ // analysis.
701+ function_ref<LayoutInfo(Value)> getAnalysisResult;
702702 Operation *top;
703703};
704704
@@ -851,22 +851,6 @@ FailureOr<VectorType> getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
851851 return VectorType::get (distributedShape, originalType.getElementType ());
852852}
853853
854- // / Get the distributed vector type for a source vector type according to an
855- // / xegpu::LayoutAttr. The vector's shape and element type are wrapped in a throwaway xegpu::TensorDescType carrying `layout`, and that descriptor's own getDistributedVectorType() is queried. Returns the VectorType directly, not a FailureOr, so callers get no way to observe a failure.
856- static VectorType getDistributedVectorType (xegpu::LayoutAttr layout,
857- VectorType originalType) {
858- auto shape = originalType.getShape ();
859- auto distVecTyOrFailure =
860- xegpu::TensorDescType::get (shape, originalType.getElementType (), // descriptor is built only to reuse its distribution query
861- /* array_length=*/ 1 , /* boundary_check=*/ true ,
862- /* memory_space=*/ xegpu::MemorySpace::Global,
863- layout)
864- .getDistributedVectorType ();
865- assert (llvm::succeeded (distVecTyOrFailure) &&
866- " Failed to compute distributed vector type for the given vector type" ); // the assert is the only failure check; the standard assert macro is disabled when NDEBUG is defined
867- return distVecTyOrFailure.value (); // reached unconditionally, including when the assert above is compiled out
868- }
869-
870854// / Drop the layout attribute from the tensor descriptor type if layout is
871855// / present.
872856static xegpu::TensorDescType dropLayouts (xegpu::TensorDescType tensorDesc) {
@@ -1175,7 +1159,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
11751159 // / supported by the store op. Type mismatch must be resolved using
11761160 // / appropriate cast op.
11771161 auto storeNdDistributedValueTyOrFailure =
1178- storeOp.getTensorDescType (). getDistributedVectorType ( );
1162+ xegpu::getDistributedVectorType ( storeOp.getTensorDescType ());
11791163 if (failed (storeNdDistributedValueTyOrFailure))
11801164 return rewriter.notifyMatchFailure (
11811165 storeOp, " Failed to get distributed vector type for the store op" );
@@ -1263,7 +1247,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
12631247 // / type.
12641248 rewriter.setInsertionPointAfter (newWarpOp);
12651249 auto loadNdDistValueTyOrFailure =
1266- loadOp.getTensorDescType (). getDistributedVectorType ( );
1250+ xegpu::getDistributedVectorType ( loadOp.getTensorDescType ());
12671251 if (failed (loadNdDistValueTyOrFailure))
12681252 return rewriter.notifyMatchFailure (
12691253 loadOp, " Failed to get distributed vector type for the load op" );
@@ -1379,17 +1363,27 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
13791363 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
13801364 rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);
13811365
1366+ FailureOr<VectorType> expectedDistLhsTyOrFailure =
1367+ xegpu::getDistributedVectorType (dpasOp.getLhsType (), layoutA);
1368+ FailureOr<VectorType> expectedDistRhsTyOrFailure =
1369+ xegpu::getDistributedVectorType (dpasOp.getRhsType (), layoutB);
1370+ FailureOr<VectorType> expectedDistResultTyOrFailure =
1371+ xegpu::getDistributedVectorType (dpasOp.getResultType (), layoutOut);
1372+ if (failed (expectedDistLhsTyOrFailure) ||
1373+ failed (expectedDistRhsTyOrFailure) ||
1374+ failed (expectedDistResultTyOrFailure))
1375+ return rewriter.notifyMatchFailure (
1376+ dpasOp,
1377+ " Failed to get distributed vector type for the dpas operands." );
13821378 // Create a new dpas op outside the warp op.
13831379 rewriter.setInsertionPointAfter (newWarpOp);
13841380 SmallVector<Value> newDpasOperands;
13851381 SmallVector<VectorType> newDpasOperandExpectedTypes;
1382+
13861383 // / Resolve the distributed types with the original types.
1387- newDpasOperandExpectedTypes.push_back (
1388- getDistributedVectorType (layoutA, dpasOp.getLhsType ()));
1389- newDpasOperandExpectedTypes.push_back (
1390- getDistributedVectorType (layoutB, dpasOp.getRhsType ()));
1391- auto distributedResultTy =
1392- getDistributedVectorType (layoutOut, dpasOp.getResultType ());
1384+ newDpasOperandExpectedTypes.push_back (expectedDistLhsTyOrFailure.value ());
1385+ newDpasOperandExpectedTypes.push_back (expectedDistRhsTyOrFailure.value ());
1386+ auto distributedResultTy = expectedDistResultTyOrFailure.value ();
13931387 if (dpasOp.getAcc ())
13941388 newDpasOperandExpectedTypes.push_back (distributedResultTy);
13951389
0 commit comments