Skip to content

Commit e0399ac

Browse files
committed
add 1D unit tests
1 parent 06cf9b2 commit e0399ac

File tree

3 files changed

+74
-14
lines changed

3 files changed

+74
-14
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,23 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
4242
/// Appends patterns for XeGPU SIMT distribution into `patterns`.
4343
void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
4444

45-
/// Collect a set of pattern to unroll xegpu operations to a smaller shapes.
45+
/// Collect a set of patterns to unroll xegpu operations to smaller shapes.
4646
/// Users can control whether an operation to be unrolled or not, as well as
47-
/// the its target shape via `options` structure. (via setting filterConstraint
47+
/// its target shape via the `options` structure (by setting filterConstraint
4848
/// and nativeShape respectively, both of them are function refs taking `op` as
4949
/// the input).
5050
/// An `op` is unrolled to the `targetShape` as follows, for each of its
5151
/// operands:
5252
/// 1. the unrolled type `unrolledType` and number of unrolled instances
5353
/// `numUnrolledInstances` are computed from the `targetShape`.
54-
/// 2. ExtractStridedSlice are created to break-up the vector operands. And
55-
/// BuildinUnrealizedCastop are created to break-up the TensorDesc operands.
54+
/// 2. pack each operand. ExtractStridedSlice are created to break-up the
55+
/// vector operands. And BuiltinUnrealizedCastOp are created to break-up
56+
/// the TensorDesc operands.
5657
/// 3. the original op is cloned `numUnrolledInstances` times, once for each
5758
/// result.
58-
/// 4. InsertStridedSlice are inserted for VectorType result, and
59-
/// BuildinUnrealizedCastOp are inserted for TensorDescType result to
60-
/// re-assemble the slices into the original shape.
59+
/// 4. unpack the results. InsertStridedSlice are inserted for VectorType
60+
/// result, and BuiltinUnrealizedCastOp are inserted for TensorDescType result
61+
/// to re-assemble the slices into the original shape.
6162
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
6263
const UnrollOptions &options);
6364

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
105105
Value unpack(ValueRange srcs, Type destTy, llvm::ArrayRef<int64_t> blockSize,
106106
Location loc, PatternRewriter &rewriter) const {
107107
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
108-
assert(vecTy.getRank() == 2 && blockSize.size() == 2 &&
108+
assert(vecTy.getRank() == (int64_t)blockSize.size() &&
109109
"Expecting blockSize size to match the rank of destTy.");
110110
auto shape = vecTy.getShape();
111111
auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
@@ -141,7 +141,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
141141
llvm::ArrayRef<int64_t> blockSize, Location loc,
142142
PatternRewriter &rewriter) const {
143143
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
144-
assert(vecTy.getRank() == 2 && blockSize.size() == 2 &&
144+
assert(vecTy.getRank() == (int64_t)blockSize.size() &&
145145
"Expecting blockSize size to match the rank of src.");
146146
auto shape = vecTy.getShape();
147147
llvm::SmallVector<Value> results;
@@ -339,10 +339,6 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
339339
auto tdescTy = op.getTensorDescType();
340340
auto shape = tdescTy.getShape();
341341

342-
// TODO: enable 1D block tensor desc
343-
if (tdescTy.getRank() != 2)
344-
return failure();
345-
346342
auto maybeTargetShape = getTargetShape(op);
347343
if (!maybeTargetShape || llvm::equal(*maybeTargetShape, shape))
348344
return failure();

mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ gpu.module @test {
44

55
// CHECK-LABEL: test_create_nd_tdesc
66
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
7-
// CHECK-COUNT-6: [[data:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
7+
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
88
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
99
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
1010
// CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
@@ -17,6 +17,19 @@ gpu.module @test {
1717

1818
//-----
1919

20+
// CHECK-LABEL: test_create_nd_tdesc_1d
21+
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
22+
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
23+
// CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
24+
// CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
25+
// CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xetile_blocking_inner_block__ = array<i64: 16>, __xetile_blocking_unpack__}
26+
gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
27+
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
28+
gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
29+
}
30+
31+
//-----
32+
2033
// CHECK-LABEL: test_update_nd_tdesc
2134
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
2235
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -29,6 +42,18 @@ gpu.module @test {
2942

3043
//-----
3144

45+
// CHECK-LABEL: test_update_nd_tdesc_1d
46+
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
47+
// CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
48+
// CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
49+
gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
50+
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
51+
%update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
52+
gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
53+
}
54+
55+
//-----
56+
3257
// CHECK-LABEL: test_prefetch_nd_tdesc
3358
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
3459
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -39,6 +64,18 @@ gpu.module @test {
3964
gpu.return
4065
}
4166

67+
//-----
68+
69+
// CHECK-LABEL: test_prefetch_nd_tdesc_1d
70+
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
71+
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
72+
// CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
73+
gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
74+
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
75+
xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
76+
gpu.return
77+
}
78+
4279
//-----
4380
// CHECK-LABEL: test_load_nd
4481
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
@@ -53,6 +90,19 @@ gpu.module @test {
5390

5491
//-----
5592

93+
// CHECK-LABEL: test_load_nd_1d
94+
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
95+
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
96+
// CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
97+
// CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
98+
gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
99+
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
100+
%data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
101+
gpu.return %data : vector<64xf32>
102+
}
103+
104+
//-----
105+
56106
// CHECK-LABEL: test_store_nd
57107
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
58108
// CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -66,6 +116,19 @@ gpu.module @test {
66116

67117
//-----
68118

119+
// CHECK-LABEL: test_store_nd_1d
120+
// CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
121+
// CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
122+
// CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32>
123+
gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
124+
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
125+
%data = arith.constant dense<9.0> : vector<64xf32>
126+
xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
127+
gpu.return
128+
}
129+
130+
//-----
131+
69132
// CHECK-LABEL: test_createNd_loadNd_storeNd
70133
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
71134
//CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

0 commit comments

Comments
 (0)