
Commit 2f2ec10 ("fix issues")
1 parent: 03bfe08

3 files changed (+151, -98 lines)

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (6 additions, 6 deletions)
@@ -676,12 +676,12 @@ void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
 LogicalResult DpasOp::verify() {
-  // int64_t lhsRank = getLhsType().getRank();
-  // int64_t rhsRank = getRhsType().getRank();
-  // int64_t resRank = getResultType().getRank();
-  // auto lhsShape = getLhsType().getShape();
-  // auto rhsShape = getRhsType().getShape();
-  // auto resShape = getResultType().getShape();
+  int64_t lhsRank = getLhsType().getRank();
+  int64_t rhsRank = getRhsType().getRank();
+  int64_t resRank = getResultType().getRank();
+  auto lhsShape = getLhsType().getShape();
+  auto rhsShape = getRhsType().getShape();
+  auto resShape = getResultType().getShape();

   if (getAcc()) {
     if (getAcc().getType() != getResultType())
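The restored locals above only set up the data for verification; the checks they feed sit outside this hunk. As a rough sketch of the kind of constraints such a verifier typically enforces (the exact ranks and shape relations are an assumption for illustration, not taken from this commit):

    // Hedged sketch only: assumes plain 2-D A (MxK), B (KxN), result (MxN);
    // the real DpasOp::verify may accept other ranks or packed layouts.
    if (lhsRank != 2 || rhsRank != 2 || resRank != 2)
      return emitOpError("expected 2-D lhs, rhs and result");
    if (lhsShape[1] != rhsShape[0] || resShape[0] != lhsShape[0] ||
        resShape[1] != rhsShape[1])
      return emitOpError("incompatible dpas operand shapes");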

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (91 additions, 38 deletions)
@@ -21,6 +21,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -679,17 +680,7 @@ void attachLayoutAttributeToUsers(Value v, xegpu::LayoutAttr layout) {
   for (OpOperand &user : v.getUses()) {
     Operation *owner = user.getOwner();
     unsigned operandNumber = user.getOperandNumber();
-    /// If the user is a DpasOp, set A, B or C layout attributes.
-    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
-      if (operandNumber == 0)
-        dpasOp.setALayoutAttr(layout);
-      else if (operandNumber == 1)
-        dpasOp.setBLayoutAttr(layout);
-      else if (operandNumber == 2)
-        dpasOp.setCLayoutAttr(layout);
-      continue;
-    }
-    /// For every other user, use a generic attribute name.
+    /// Use a generic name for ease of querying the layout attribute later.
     std::string attrName =
         operandLayoutNamePrefix + std::to_string(operandNumber);
     owner->setAttr(attrName, layout);
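The generic name written here is the contract with DpasDistribution further down, which rebuilds the same string with llvm::formatv and looks the attribute up again. A minimal illustration of that pairing, using only expressions that appear in this commit (the concrete value of operandLayoutNamePrefix is defined elsewhere in the file and not shown in this diff):

    // Setter side (attachLayoutAttributeToUsers) and query side
    // (DpasDistribution) must produce the same name for operand 0.
    std::string stored = operandLayoutNamePrefix + std::to_string(0);
    std::string queried =
        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str();
    assert(stored == queried && "attribute-name scheme must match on both sides");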
@@ -824,18 +815,66 @@ static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
   return distVecTyOrFailure.value();
 }

-static Value reshapeDistributedVecType(Value orig, VectorType expected,
-                                       PatternRewriter &rewriter) {
-  assert(isa<VectorType>(orig.getType()) && "expecting vector type");
-  auto origVecType = cast<VectorType>(orig.getType());
-  /// No need to reconcile if the types are the same.
-  if (origVecType == expected)
+static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
+  return xegpu::TensorDescType::get(
+      tensorDesc.getContext(), tensorDesc.getShape(),
+      tensorDesc.getElementType(), tensorDesc.getEncoding(),
+      xegpu::LayoutAttr());
+}
+
+template <typename T>
+static Value resolveDistributedTy(Value orig, T expected,
+                                  PatternRewriter &rewriter) {
+  /// If orig and expected types are the same, return orig.
+  if (orig.getType() == expected)
     return orig;
-  auto castOp =
-      rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
-  return castOp.getResult();
+  /// If orig is a vector type, create a shape cast op to reconcile the types.
+  if (auto origVecType = isa<VectorType>(orig.getType())) {
+    auto castOp =
+        rewriter.create<vector::ShapeCastOp>(orig.getLoc(), expected, orig);
+    return castOp.getResult();
+  }
+  /// If orig is a tensor descriptor type, create an unrealized conversion cast
+  /// op to reconcile the types.
+  if (auto origTensorDescTy = isa<xegpu::TensorDescType>(orig.getType())) {
+    auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
+                                                              expected, orig);
+    return castOp.getResult(0);
+  }
+  llvm_unreachable("Unsupported type for reconciliation");
+  return orig;
 }

+// static Value reconcileDistributedTensorDescTy(Value orig,
+//                                               xegpu::TensorDescType expected,
+//                                               PatternRewriter &rewriter) {
+//   assert(isa<xegpu::TensorDescType>(orig.getType()) &&
+//          "expecting tensor descriptor type");
+//   auto origTensorDescTy = cast<xegpu::TensorDescType>(orig.getType());
+//   /// No need to reconcile if the types are the same.
+//   if (origTensorDescTy == expected)
+//     return orig;
+//   auto castOp = rewriter.create<UnrealizedConversionCastOp>(orig.getLoc(),
+//                                                             expected, orig);
+//   return castOp.getResult(0);
+// }
+
+// // unify above 2 functions with a template
+// template <typename T>
+// static Value reconcileDistributedType(Value orig, T expected,
+//                                       PatternRewriter &rewriter) {
+//   if constexpr (std::is_same_v<T, VectorType>) {
+//     return reconcileDistributedVecType(orig, expected, rewriter);
+//   } else if constexpr (std::is_same_v<T, xegpu::TensorDescType>) {
+//     return reconcileDistributedTensorDescTy(orig, expected, rewriter);
+//   } else {
+//     static_assert(llvm::is_one_of<T, VectorType,
+//                                   xegpu::TensorDescType>::value,
+//                   "Unsupported type for reconciliation");
+//   }
+//   return orig;
+// }
+
 static SmallVector<NamedAttribute>
 filterTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
   SmallVector<NamedAttribute> newAttrs;
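Together, dropLayouts and the templated resolveDistributedTy replace the earlier vector-only reshape helper: one call site can now reconcile both vector values (via vector.shape_cast) and tensor descriptors whose layout attribute was dropped (via an unrealized conversion cast). A short usage sketch mirroring the store pattern below; the helper names come from this diff, while newWarpOp, newRetIndices, storeOp and rewriter stand in for values available inside a distribution pattern, and error handling from the real pattern is omitted:

    SmallVector<Value> newOperands;
    // Vector operand: reconciled with a vector.shape_cast when shapes differ.
    VectorType distValueTy =
        storeOp.getTensorDescType().getDistributedVectorType().value();
    newOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[0]), distValueTy, rewriter));
    // Tensor descriptor operand: its expected type has the layout dropped, so
    // an unrealized_conversion_cast reconciles it.
    newOperands.push_back(resolveDistributedTy(
        newWarpOp.getResult(newRetIndices[1]),
        dropLayouts(storeOp.getTensorDescType()), rewriter));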
@@ -951,7 +990,7 @@ struct MoveFuncBodyToWarpExecuteOnLane0
 /// -> !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
-struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
+struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -993,8 +1032,11 @@ struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
       newDescOperands.push_back(newWarpOp.getResult(i));
     }
     rewriter.setInsertionPointAfter(newWarpOp);
+    auto distributedTensorDescTy =
+        dropLayouts(descOp.getType()); /// Distributed tensor descriptor type
+                                       /// does not contain layout info.
     auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
-        newWarpOp.getLoc(), descOp.getType(), newDescOperands,
+        newWarpOp.getLoc(), distributedTensorDescTy, newDescOperands,
         descOp->getAttrs());

     Value distributedVal = newWarpOp.getResult(operandIdx);
@@ -1027,7 +1069,7 @@ struct SubgroupOpTensorDescOp final : public gpu::WarpDistributionPattern {
 /// !xegpu.tensor_desc<4x8xf32>
 ///
 /// ```
-struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
+struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -1065,19 +1107,24 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
     rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newStoreOperands;

-    /// For the value operand, there can be a conflict between the vector type
+    /// For the value operand, there can be a mismatch between the vector type
     /// distributed by the warp op and (xegpu-specific) distributed type
-    /// supported by the store op. We reconcile these mismatches by inserting
-    /// a cast. These gets cancelled out later.
+    /// supported by the store op. Type mismatch must be resolved using
+    /// appropriate cast op.
    auto storeNdDistributedValueTyOrFailure =
         storeOp.getTensorDescType().getDistributedVectorType();
     if (failed(storeNdDistributedValueTyOrFailure))
       return rewriter.notifyMatchFailure(
           storeOp, "Failed to get distributed vector type for the store op");
-    newStoreOperands.push_back(reshapeDistributedVecType(
+    newStoreOperands.push_back(resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[0]),
         storeNdDistributedValueTyOrFailure.value(), rewriter));
-    newStoreOperands.push_back(newWarpOp.getResult(newRetIndices[1]));
+    /// For the tensor descriptor operand, the layout attibute is dropped after
+    /// distribution. Types needs to be resolved in this case also.
+    auto distributedTensorDescTy = dropLayouts(storeOp.getTensorDescType());
+    newStoreOperands.push_back(
+        resolveDistributedTy(newWarpOp.getResult(newRetIndices[1]),
+                             distributedTensorDescTy, rewriter));

     rewriter.create<xegpu::StoreNdOp>(
         newWarpOp.getLoc(), TypeRange{}, newStoreOperands,
@@ -1117,7 +1164,7 @@ struct SubgroupOpStoreNd final : public gpu::WarpDistributionPattern {
 /// %ld = xegpu.load_nd %r#0: !xegpu.tensor_desc<4x8xf32> -> vector<4x1xf32>
 ///
 /// ```
-struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
+struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -1161,13 +1208,13 @@ struct SubgroupOpLoadNd final : public gpu::WarpDistributionPattern {
     /// warp op and (xegpu-specific) distributed type supported by the load
     /// op. We reconcile these mismatches by inserting a cast.
     newLoadOp =
-        reshapeDistributedVecType(newLoadOp, distributedTypeByWarpOp, rewriter);
+        resolveDistributedTy(newLoadOp, distributedTypeByWarpOp, rewriter);
     rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
     return success();
   }
 };

-struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
+struct DpasDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op subgroupOp,
                                 PatternRewriter &rewriter) const override {
@@ -1179,15 +1226,21 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {

     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    xegpu::LayoutAttr layoutA = dpasOp.getALayoutAttr();
-    xegpu::LayoutAttr layoutB = dpasOp.getBLayoutAttr();
+    auto layoutAName =
+        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str();
+    auto layoutBName =
+        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 1).str();
     auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str();
+    xegpu::LayoutAttr layoutA =
+        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
+    xegpu::LayoutAttr layoutB =
+        dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutBName);
     xegpu::LayoutAttr layoutOut =
         dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutCName);
     if (!layoutA || !layoutB || !layoutOut)
       return rewriter.notifyMatchFailure(
           dpasOp,
-          "the xegpu::Dpas op lacks sg_map attribute for A, B or output");
+          "the xegpu::Dpas op lacks layout attribute for A, B or output");

     auto distLhsTypeByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutA, dpasOp.getLhsType());
@@ -1232,7 +1285,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
     }

     for (auto i : newRetIndices) {
-      newDpasOperands.push_back(reshapeDistributedVecType(
+      newDpasOperands.push_back(resolveDistributedTy(
          newWarpOp.getResult(i),
          newDpasOperandExpectedTypes[newDpasOperands.size()], rewriter));
     }
@@ -1241,7 +1294,7 @@ struct SubgroupOpDpas final : public gpu::WarpDistributionPattern {
         newDpasOperands, dpasOp->getAttrs());
     Value disributedVal = newWarpOp.getResult(operandIdx);
     /// Reconile the output type.
-    disributedVal = reshapeDistributedVecType(
+    disributedVal = resolveDistributedTy(
         disributedVal,
         getDistributedVectorType(layoutOut, dpasOp.getResultType()), rewriter);
     rewriter.replaceAllUsesWith(disributedVal, newDpasOp);
@@ -1266,8 +1319,8 @@ struct XeGPUSubgroupDistributePass final

 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<SubgroupOpTensorDescOp, SubgroupOpStoreNd, SubgroupOpLoadNd,
-               SubgroupOpDpas>(patterns.getContext());
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution>(patterns.getContext());
 }

 void XeGPUSubgroupDistributePass::runOnOperation() {

mlir/test/Dialect/XeGPU/subgroup-distribution.mlir (54 additions, 54 deletions)
@@ -8,63 +8,63 @@ gpu.func @test_store_nd_1d(%arg0: memref<16xf32>){
 }
 }

-// -----
-gpu.module @test {
-gpu.func @test_store_nd_2d(%arg0: memref<16x16xf16>){
-%c0 = arith.constant 0 : index
-%1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
-%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-xegpu.store_nd %1, %0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.return
-}
-}
+// // -----
+// gpu.module @test {
+// gpu.func @test_store_nd_2d(%arg0: memref<16x16xf16>){
+// %c0 = arith.constant 0 : index
+// %1 = arith.constant dense<1.000000e+00> : vector<16x16xf16>
+// %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// xegpu.store_nd %1, %0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+// gpu.return
+// }
+// }



-// -----
-gpu.module @test {
-gpu.func @test_load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
-%c0 = arith.constant 0 : index
-%0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
-%2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
-xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-gpu.return
-}
-}
+// // -----
+// gpu.module @test {
+// gpu.func @test_load_nd_1d(%arg0: memref<16xf32>, %arg1: memref<16xf32>){
+// %c0 = arith.constant 0 : index
+// %0 = xegpu.create_nd_tdesc %arg0[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
+// %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+// %2 = xegpu.create_nd_tdesc %arg1[%c0] : memref<16xf32> -> !xegpu.tensor_desc<16xf32>
+// xegpu.store_nd %1, %2 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+// gpu.return
+// }
+// }

-// -----
-gpu.module @test {
-gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
-%c0 = arith.constant 0 : index
-%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
-%2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.return
-}
-}
+// // -----
+// gpu.module @test {
+// gpu.func @test_load_nd_2d(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
+// %c0 = arith.constant 0 : index
+// %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// %2 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+// gpu.return
+// }
+// }

-// -----
-gpu.module @test {
-gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
-%c0 = arith.constant 0 : index
-%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
-%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
-%2 = vector.extract %1[%c0] : vector<16x16xf16> from vector<2x16x16xf16>
-%3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
-xegpu.store_nd %2, %3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
-gpu.return
-}
-}
+// // -----
+// gpu.module @test {
+// gpu.func @test_load_nd_array_length(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>){
+// %c0 = arith.constant 0 : index
+// %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>>
+// %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<array_length = 2 : i64>> -> vector<2x16x16xf16>
+// %2 = vector.extract %1[%c0] : vector<16x16xf16> from vector<2x16x16xf16>
+// %3 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// xegpu.store_nd %2, %3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+// gpu.return
+// }
+// }

-// -----
-gpu.module @test {
-gpu.func @test_dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: vector<8x16xf32>, %arg2: memref<8x16xf32>){
-%c0 = arith.constant 0 : index
-%0 = xegpu.dpas %arg0, %arg1, %arg3 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-%3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-xegpu.store_nd %0, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-gpu.return
-}
-}
+// // -----
+// gpu.module @test {
+// gpu.func @test_dpas(%arg0: vector<8x16xf16>, %arg1: vector<16x16xf16>, %arg3: vector<8x16xf32>, %arg2: memref<8x16xf32>){
+// %c0 = arith.constant 0 : index
+// %0 = xegpu.dpas %arg0, %arg1, %arg3 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// xegpu.store_nd %0, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+// gpu.return
+// }
+// }
