Skip to content

Commit 76f7323

Browse files
committed
[mlir][xegpu] WIP: add CreateNdDescOp conversion pattern and tensor_desc type conversion for the optimize-transpose pass
1 parent 9d0341d commit 76f7323

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeTranspose.cpp

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1313
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
1414
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
15+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1516
#include "mlir/Transforms/DialectConversion.h"
1617
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
#include <optional>
1719

1820
namespace mlir {
1921
namespace xegpu {
@@ -29,20 +31,36 @@ using namespace mlir;
2931

3032
namespace {
3133

32-
class XeGPULoadNdPattern final : public OpConversionPattern<xegpu::LoadNdOp> {
34+
/// Extract the effective 2-D lane_data of `tdescType`'s layout attribute.
/// Returns std::nullopt when the descriptor carries no layout or when the
/// lane_data is not rank-2.
static std::optional<SmallVector<int64_t>>
get2DLaneData(xegpu::TensorDescType tdescType) {
  if (auto layout = tdescType.getLayoutAttr()) {
    SmallVector<int64_t> laneData = layout.getEffectiveLaneDataAsInt();
    if (laneData.size() == 2)
      return laneData;
  }
  return std::nullopt;
}
44+
45+
class XeGPUCreateNdDescOpPattern final
46+
: public OpConversionPattern<xegpu::CreateNdDescOp> {
3347
public:
34-
using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
48+
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
3549
LogicalResult
36-
matchAndRewrite(xegpu::LoadNdOp loadOp, OpAdaptor adaptor,
50+
matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor,
3751
ConversionPatternRewriter &rewriter) const override {
52+
auto tdescTy = createNdOp.getType();
53+
auto convertType = this->getTypeConverter()->convertType(tdescTy);
54+
if (convertType == tdescTy)
55+
return failure();
3856
return success();
3957
}
4058
};
4159
} // namespace
4260

4361
/// Populate `patterns` with the rewrite patterns of the XeGPU
/// optimize-transpose transform. Currently registers only the
/// create_nd_tdesc conversion pattern.
void xegpu::populateXeGPUOptimizeTransposePatterns(
    RewritePatternSet &patterns) {
  patterns.add<XeGPUCreateNdDescOpPattern>(patterns.getContext());
}
4765

4866
namespace {
@@ -55,6 +73,42 @@ struct XeGPUOptimizeTransposePass final
5573
TypeConverter converter;
5674
RewritePatternSet patterns(&context);
5775
ConversionTarget target(context);
76+
77+
target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
78+
[](xegpu::CreateNdDescOp createNdOp) {
79+
auto optionalLaneData = get2DLaneData(createNdOp.getType());
80+
if (!optionalLaneData)
81+
return true;
82+
auto laneData = optionalLaneData.value();
83+
return laneData[0] != 1 || laneData[1] == 1;
84+
});
85+
86+
converter.addConversion([](xegpu::TensorDescType tdescType) {
87+
auto optionalLaneData = get2DLaneData(tdescType);
88+
if (!optionalLaneData)
89+
return tdescType;
90+
auto laneData = optionalLaneData.value();
91+
int64_t innerLaneData = laneData[1];
92+
if (laneData[0] == 1 && innerLaneData != 1) {
93+
int elementTyBitwidth =
94+
tdescType.getElementType().getIntOrFloatBitWidth();
95+
assert(elementTyBitwidth < 32 &&
96+
"Expected element type bitwidth < 32 with laneData[1] != 1");
97+
SmallVector<int64_t> newShape(tdescType.getShape());
98+
newShape.back() = newShape.back() / innerLaneData;
99+
Type newElemTy = IntegerType::get(tdescType.getContext(),
100+
elementTyBitwidth * innerLaneData);
101+
xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
102+
tdescType.getContext(),
103+
tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1});
104+
return xegpu::TensorDescType::get(
105+
newShape, newElemTy, tdescType.getArrayLength(),
106+
tdescType.getBoundaryCheck(), tdescType.getMemorySpace(),
107+
newLayout);
108+
}
109+
return tdescType;
110+
});
111+
58112
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
59113
target);
60114
xegpu::populateXeGPUOptimizeTransposePatterns(patterns);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// GEMM whose B operand is consumed transposed: %3/%6 load B with a
// column-major lane layout (lane_layout = [16, 1], lane_data = [1, 2]) and
// the result is vector.transpose'd before feeding the dpas accumulation.
// NOTE(review): no RUN line or CHECK directives yet — presumably a WIP
// FileCheck test for the xegpu optimize-transpose pass; confirm before landing.
func.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c256 = arith.constant 256 : index
  // C tile: 8x16 accumulator descriptor.
  %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
  %1 = xegpu.load_nd %0[%c0, %c0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
  // A tile (row-major lane layout) and B tile (transposed lane layout).
  %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
  %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
  // K-loop: accumulate C over 16-wide K slices.
  %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) {
    %5 = xegpu.load_nd %2[%c0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
    %6 = xegpu.load_nd %3[%c0, %arg3] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>> -> vector<16x16xf16>
    // Transpose B before the MMA so dpas sees it row-major.
    %7 = vector.transpose %6, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : vector<16x16xf16> to vector<16x16xf16>
    %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
    scf.yield %8 : vector<8x16xf32>
  } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
  xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
  return
}

0 commit comments

Comments
 (0)