Skip to content

Commit c81b2e0

Browse files
committed
fix issues
1 parent 35f9cbe commit c81b2e0

File tree

3 files changed

+192
-99
lines changed

3 files changed

+192
-99
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 66 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -597,70 +597,72 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
597597
// XeGPU_DpasOp
598598
//===----------------------------------------------------------------------===//
599599
LogicalResult DpasOp::verify() {
600-
int64_t lhsRank = getLhsType().getRank();
601-
int64_t rhsRank = getRhsType().getRank();
602-
int64_t resRank = getResultType().getRank();
603-
auto lhsShape = getLhsType().getShape();
604-
auto rhsShape = getRhsType().getShape();
605-
auto resShape = getResultType().getShape();
606-
607-
auto aLayout = getALayoutAttr();
608-
auto bLayout = getBLayoutAttr();
609-
auto cLayout = getCLayoutAttr();
610-
611-
// make sure the layout attribute is either set for every available
612-
// operand or simply not set at all. C is special, since ACC is optional.
613-
auto hasValidLayoutAttrs = [&]() {
614-
bool result = (aLayout != nullptr) ^ (bLayout != nullptr);
615-
if (hasAcc()) {
616-
result |= (aLayout != nullptr) ^ (cLayout != nullptr);
617-
}
618-
return !result;
619-
};
620-
621-
if (!hasValidLayoutAttrs())
622-
return emitOpError(
623-
"layout attributes should be either set for all operands (for SIMT "
624-
"code) or not set at all (for SIMD code).");
625-
626-
// query the scope from aLayout (a valid setting).
627-
if (aLayout) {
628-
// In SIMT mode, All data fragments must be 2D
629-
if (lhsRank != 2 || rhsRank != 2 || resRank != 2)
630-
return emitOpError("expecting lhs, rhs, and result to be a 2D vector.");
631-
632-
auto laneLayoutA = aLayout.getLaneLayout();
633-
auto laneLayoutB = bLayout.getLaneLayout();
634-
auto laneLayoutC = cLayout.getLaneLayout();
635-
// Obtain the expanded shapes of the operands and result using lane_layout.
636-
// NOTE: For B, get rid of the packed dimension for the expanded shape.
637-
SmallVector<int64_t> expandedShapeA = {lhsShape[0] * laneLayoutA[0],
638-
lhsShape[1] * laneLayoutA[1]};
639-
SmallVector<int64_t> expandedShapeB = {
640-
rhsShape[0] * rhsShape[1] * laneLayoutB[0], 1 * laneLayoutB[1]};
641-
SmallVector<int64_t> expandedShapeC = {resShape[0] * laneLayoutC[0],
642-
resShape[1] * laneLayoutC[1]};
643-
auto bK = expandedShapeB[0];
644-
if (bK != expandedShapeA[1])
645-
return emitOpError("K-dimension mismatch.");
646-
if (expandedShapeA[0] != expandedShapeC[0])
647-
return emitOpError("M-dimension mismatch.");
648-
if (expandedShapeB[1] != expandedShapeC[1])
649-
return emitOpError("N-dimension mismatch.");
650-
} else { // For other scopes, operands' shape should match the mxkxn
651-
// semantics.
652-
if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
653-
return emitOpError(
654-
"expecting lhs and result to be a 2D vector, and rhs to be either "
655-
"2D or 3D (packed) vector.");
656-
auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
657-
if (bK != lhsShape[1])
658-
return emitOpError("K-dimension mismatch.");
659-
if (lhsShape[0] != resShape[0])
660-
return emitOpError("M-dimension mismatch.");
661-
if (rhsShape[1] != resShape[1])
662-
return emitOpError("N-dimension mismatch.");
663-
}
600+
// int64_t lhsRank = getLhsType().getRank();
601+
// int64_t rhsRank = getRhsType().getRank();
602+
// int64_t resRank = getResultType().getRank();
603+
// auto lhsShape = getLhsType().getShape();
604+
// auto rhsShape = getRhsType().getShape();
605+
// auto resShape = getResultType().getShape();
606+
607+
// auto aLayout = getALayoutAttr();
608+
// auto bLayout = getBLayoutAttr();
609+
// auto cLayout = getCLayoutAttr();
610+
611+
// // make sure the layout attribute is either set for every available
612+
// // operand or simply not set at all. C is special, since ACC is optional.
613+
// auto hasValidLayoutAttrs = [&]() {
614+
// bool result = (aLayout != nullptr) ^ (bLayout != nullptr);
615+
// if (hasAcc()) {
616+
// result |= (aLayout != nullptr) ^ (cLayout != nullptr);
617+
// }
618+
// return !result;
619+
// };
620+
621+
// if (!hasValidLayoutAttrs())
622+
// return emitOpError(
623+
// "layout attributes should be either set for all operands (for SIMT "
624+
// "code) or not set at all (for SIMD code).");
625+
626+
// // query the scope from aLayout (a valid setting).
627+
// if (aLayout) {
628+
// // In SIMT mode, All data fragments must be 2D
629+
// if (lhsRank != 2 || rhsRank != 2 || resRank != 2)
630+
// return emitOpError("expecting lhs, rhs, and result to be a 2D
631+
// vector.");
632+
633+
// auto laneLayoutA = aLayout.getLaneLayout();
634+
// auto laneLayoutB = bLayout.getLaneLayout();
635+
// auto laneLayoutC = cLayout.getLaneLayout();
636+
// // Obtain the expanded shapes of the operands and result using
637+
// lane_layout.
638+
// // NOTE: For B, get rid of the packed dimension for the expanded shape.
639+
// SmallVector<int64_t> expandedShapeA = {lhsShape[0] * laneLayoutA[0],
640+
// lhsShape[1] * laneLayoutA[1]};
641+
// SmallVector<int64_t> expandedShapeB = {
642+
// rhsShape[0] * rhsShape[1] * laneLayoutB[0], 1 * laneLayoutB[1]};
643+
// SmallVector<int64_t> expandedShapeC = {resShape[0] * laneLayoutC[0],
644+
// resShape[1] * laneLayoutC[1]};
645+
// auto bK = expandedShapeB[0];
646+
// if (bK != expandedShapeA[1])
647+
// return emitOpError("K-dimension mismatch.");
648+
// if (expandedShapeA[0] != expandedShapeC[0])
649+
// return emitOpError("M-dimension mismatch.");
650+
// if (expandedShapeB[1] != expandedShapeC[1])
651+
// return emitOpError("N-dimension mismatch.");
652+
// } else { // For other scopes, operands' shape should match the mxkxn
653+
// // semantics.
654+
// if (lhsRank != 2 || (rhsRank != 2 && rhsRank != 3) || resRank != 2)
655+
// return emitOpError(
656+
// "expecting lhs and result to be a 2D vector, and rhs to be either "
657+
// "2D or 3D (packed) vector.");
658+
// auto bK = rhsRank == 3 ? rhsShape[0] * rhsShape[2] : rhsShape[0];
659+
// if (bK != lhsShape[1])
660+
// return emitOpError("K-dimension mismatch.");
661+
// if (lhsShape[0] != resShape[0])
662+
// return emitOpError("M-dimension mismatch.");
663+
// if (rhsShape[1] != resShape[1])
664+
// return emitOpError("N-dimension mismatch.");
665+
// }
664666
return success();
665667
}
666668

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
#include "llvm/ADT/TypeSwitch.h"
3939
#include "llvm/ADT/bit.h"
4040
#include "llvm/Support/Casting.h"
41+
#include "llvm/Support/Format.h"
42+
#include "llvm/Support/FormatVariadic.h"
4143
#include "llvm/Support/LogicalResult.h"
4244
#include "llvm/Support/raw_ostream.h"
4345

@@ -63,6 +65,8 @@ constexpr unsigned packedSizeInBitsForDefault =
6365
16; // Minimum packing size per register for DPAS A.
6466
constexpr unsigned packedSizeInBitsForDpasB =
6567
32; // Minimum packing size per register for DPAS B.
68+
static const char *const operandLayoutNamePrefix = "layout_operand_";
69+
static const char *const resultLayoutNamePrefix = "layout_result_";
6670

6771
namespace {
6872

@@ -686,7 +690,8 @@ void attachLayoutAttributeToUsers(Value v, xegpu::LayoutAttr layout) {
686690
continue;
687691
}
688692
/// For every other user, use a generic attribute name.
689-
std::string attrName = "op" + std::to_string(operandNumber);
693+
std::string attrName =
694+
operandLayoutNamePrefix + std::to_string(operandNumber);
690695
owner->setAttr(attrName, layout);
691696
}
692697
}
@@ -746,7 +751,7 @@ static LogicalResult attachLayoutAttributes(
746751
for (auto [i, r] : llvm::enumerate(op->getResults())) {
747752
auto layoutInfo = getLayoutInfoForResult(r);
748753
if (layoutInfo) {
749-
auto attrName = "r" + std::to_string(i);
754+
auto attrName = resultLayoutNamePrefix + std::to_string(i);
750755
op->setAttr(attrName, layoutInfo);
751756
/// Attach the layout attribute to the users of the result.
752757
attachLayoutAttributeToUsers(r, layoutInfo);
@@ -819,16 +824,29 @@ static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
819824
return distVecTyOrFailure.value();
820825
}
821826

822-
static Value reconcileDistribtedVecType(Value orig, VectorType expected,
823-
PatternRewriter &rewriter) {
827+
static Value reshapeDistributedVecType(Value orig, VectorType expected,
828+
PatternRewriter &rewriter) {
824829
assert(isa<VectorType>(orig.getType()) && "expecting vector type");
825830
auto origVecType = cast<VectorType>(orig.getType());
826831
/// No need to reconcile if the types are the same.
827832
if (origVecType == expected)
828833
return orig;
829-
auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
830-
expected, orig);
831-
return castOp.getResult(0);
834+
auto castOp =
835+
rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
836+
return castOp.getResult();
837+
}
838+
839+
static SmallVector<NamedAttribute>
840+
filterTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
841+
SmallVector<NamedAttribute> newAttrs;
842+
for (auto attr : attrs) {
843+
if (attr.getName().strref().contains(operandLayoutNamePrefix) ||
844+
attr.getName().strref().contains(resultLayoutNamePrefix)) {
845+
continue;
846+
}
847+
newAttrs.push_back(attr);
848+
}
849+
return newAttrs;
832850
}
833851

834852
/// Given a GPUFuncOp, this pattern creates a new GPUFuncOp and moves the body
@@ -903,11 +921,11 @@ struct MoveFuncBodyToWarpExecuteOnLane0
903921
};
904922

905923
/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
906-
/// `gpu.warp_execute_on_lane_0` and put it after the warp op. The warp op will
907-
/// still contain the original op that will not be used by the yield op (and
908-
/// should be cleaned up later with dce). The yield op will bypass the
909-
/// create_nd_tdesc's arguments. Tensor descriptor is not distributed because it
910-
/// is a uniform value accorss all work items within the subgroup.
924+
/// `gpu.warp_execute_on_lane_0` and put it after the warp op. The warp op
925+
/// will still contain the original op that will not be used by the yield op
926+
/// (and should be cleaned up later with dce). The yield op will bypass the
927+
/// create_nd_tdesc's arguments. Tensor descriptor is not distributed because
928+
/// it is a uniform value accorss all work items within the subgroup.
911929
///
912930
/// Example:
913931
///
@@ -985,10 +1003,10 @@ struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
9851003
}
9861004
};
9871005

988-
/// Sink a store_nd op at the end of enclosing `gpu.warp_execute_on_lane_0`. In
989-
/// case arguments for the store are passed through the warp op interface they
990-
/// would be propagated as returned values. Only the source vector for the store
991-
/// is distributed according to sg_map attribute.
1006+
/// Sink a store_nd op at the end of enclosing `gpu.warp_execute_on_lane_0`.
1007+
/// In case arguments for the store are passed through the warp op interface
1008+
/// they would be propagated as returned values. Only the source vector for
1009+
/// the store is distributed according to sg_map attribute.
9921010
///
9931011
/// Example:
9941012
///
@@ -1033,7 +1051,6 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10331051
"Failed to distribute the type");
10341052
VectorType distributedTypeByWarpOp =
10351053
distributedTypeByWarpOpOrFailure.value();
1036-
llvm::errs() << "distributed type: " << distributedTypeByWarpOp << "\n";
10371054

10381055
SmallVector<size_t> newRetIndices;
10391056
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1050,21 +1067,21 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10501067

10511068
/// For the value operand, there can be a conflict between the vector type
10521069
/// distributed by the warp op and (xegpu-specific) distributed type
1053-
/// supported by the store op. We reconcile these mismatches by inserting a
1054-
/// cast. These gets cancelled out later.
1070+
/// supported by the store op. We reconcile these mismatches by inserting
1071+
/// a cast. These gets cancelled out later.
10551072
auto storeNdDistributedValueTyOrFailure =
10561073
storeOp.getTensorDescType().getDistributedVectorType();
10571074
if (failed(storeNdDistributedValueTyOrFailure))
10581075
return rewriter.notifyMatchFailure(
10591076
storeOp, "Failed to get distributed vector type for the store op");
1060-
newStoreOperands.push_back(reconcileDistribtedVecType(
1077+
newStoreOperands.push_back(reshapeDistributedVecType(
10611078
newWarpOp.getResult(newRetIndices[0]),
10621079
storeNdDistributedValueTyOrFailure.value(), rewriter));
10631080
newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[1]));
10641081

1065-
rewriter.create<xegpu::StoreNdOp>(newWarpOp.getLoc(), TypeRange{},
1066-
newStoreOperands);
1067-
storeOp->setDialectAttrs(storeOp->getDialectAttrs());
1082+
rewriter.create<xegpu::StoreNdOp>(
1083+
newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
1084+
filterTemporaryLayoutAttributes(storeOp->getAttrs()));
10681085
rewriter.eraseOp(storeOp);
10691086
return success();
10701087
}
@@ -1074,8 +1091,9 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
10741091
/// `gpu.warp_execute_on_lane_0` and put it after the warp op.
10751092
/// The warp op will still contain the original op that will not be used by
10761093
/// the yield op (and should be cleaned up later with dce). The yield op will
1077-
/// bypass the load's arguments. Only the loaded vector is distributed according
1078-
/// to sg_map attribute and, tensor descriptor types is not distributed.
1094+
/// bypass the load's arguments. Only the loaded vector is distributed
1095+
/// according to sg_map attribute and, tensor descriptor types is not
1096+
/// distributed.
10791097
///
10801098
/// Example:
10811099
///
@@ -1122,7 +1140,8 @@ struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
11221140

11231141
SmallVector<size_t> newRetIndices;
11241142
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
1125-
rewriter, subgroupOp, /* new yielded values = */ loadOp.getTensorDesc(),
1143+
rewriter, subgroupOp,
1144+
/* new yielded values = */ loadOp.getTensorDesc(),
11261145
/* new yielded types = */ tensorDescTy, newRetIndices);
11271146

11281147
/// Create a new load op outside the warp op with the distributed vector
@@ -1135,13 +1154,14 @@ struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
11351154
loadOp, "Failed to get distributed vector type for the load op");
11361155
Value newLoadOp = rewriter.create<xegpu::LoadNdOp>(
11371156
newWarpOp.getLoc(), loadNdDistValueTyOrFailure.value(),
1138-
newWarpOp->getResult(newRetIndices[0]), loadOp->getAttrs());
1157+
newWarpOp->getResult(newRetIndices[0]),
1158+
filterTemporaryLayoutAttributes(loadOp->getAttrs()));
11391159
Value distributedVal = newWarpOp.getResult(operandIdx);
1140-
/// There can be a conflict between the vector type distributed by the warp
1141-
/// op and (xegpu-specific) distributed type supported by the load op. We
1142-
/// reconcile these mismatches by inserting a cast.
1143-
newLoadOp = reconcileDistribtedVecType(newLoadOp, distributedTypeByWarpOp,
1144-
rewriter);
1160+
/// There can be a conflict between the vector type distributed by the
1161+
/// warp op and (xegpu-specific) distributed type supported by the load
1162+
/// op. We reconcile these mismatches by inserting a cast.
1163+
newLoadOp =
1164+
reshapeDistributedVecType(newLoadOp, distributedTypeByWarpOp, rewriter);
11451165
rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
11461166
return success();
11471167
}
@@ -1161,8 +1181,9 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
11611181
unsigned operandIdx = operand->getOperandNumber();
11621182
xegpu::LayoutAttr layoutA = dpasOp.getALayoutAttr();
11631183
xegpu::LayoutAttr layoutB = dpasOp.getBLayoutAttr();
1184+
auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str();
11641185
xegpu::LayoutAttr layoutOut =
1165-
dpasOp->getAttrOfType<xegpu::LayoutAttr>("r0");
1186+
dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
11661187
if (!layoutA || !layoutB || !layoutOut)
11671188
return rewriter.notifyMatchFailure(
11681189
dpasOp,
@@ -1211,7 +1232,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
12111232
}
12121233

12131234
for (auto i : newRetIndices) {
1214-
newDpasOperands.push_back(reconcileDistribtedVecType(
1235+
newDpasOperands.push_back(reshapeDistributedVecType(
12151236
newWarpOp.getResult(i),
12161237
newDpasOperandExpectedTypes[newDpasOperands.size()], rewriter));
12171238
}
@@ -1220,7 +1241,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
12201241
newDpasOperands, dpasOp->getAttrs());
12211242
Value disributedVal = newWarpOp.getResult(operandIdx);
12221243
/// Reconile the output type.
1223-
disributedVal = reconcileDistribtedVecType(
1244+
disributedVal = reshapeDistributedVecType(
12241245
disributedVal,
12251246
getDistributedVectorType(layoutOut, dpasOp.getResultType()), rewriter);
12261247
rewriter.replaceAllUsesWith(disributedVal, newDpasOp);

0 commit comments

Comments
 (0)