Skip to content

Commit e510643

Browse files
committed
Fix bugs and add test case for high rank base memref.
1 parent 4a92953 commit e510643

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,15 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
152152
}
153153

154154
// Compute the product of sizes in the range [lo, hi) from the sizes array.
155+
// Note: all sizes are i64.
155156
static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
156157
Location loc, ArrayRef<OpFoldResult> sizes,
157158
size_t lo, size_t hi) {
158-
Type indexTy = rewriter.getIndexType();
159-
Value product = arith::ConstantIndexOp::create(rewriter, loc, 1);
159+
Value product =
160+
arith::ConstantIntOp::create(rewriter, loc, rewriter.getI64Type(), 1);
160161
for (size_t idx = lo; idx < hi; idx++) {
161162
OpFoldResult ofr = sizes[idx];
162163
Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
163-
sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc, indexTy, sizeVal);
164164
product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
165165
}
166166
return product;
@@ -233,6 +233,8 @@ class CreateNdDescToXeVMPattern
233233
// Generate compute chain for height (product of sizes of all but the last
234234
// dimension).
235235
baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
236+
baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
237+
baseShapeH);
236238
}
237239
if (sourceMemrefTy) {
238240
// Cast index to i64.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
2+
3+
gpu.module @load_store_check {
4+
// CHECK: fail
5+
gpu.func @load_store(%src: memref<3x3x8x16xf32, 1>, %dst: memref<3x3x8x16xf32, 1>) kernel {
6+
%srcce = memref.memory_space_cast %src : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
7+
%dstte = memref.memory_space_cast %dst : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
8+
9+
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
10+
11+
%loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
12+
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
13+
14+
%tid_x = gpu.thread_id x
15+
%tid_x_i32 = arith.index_cast %tid_x : index to i32
16+
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
17+
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
18+
19+
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
20+
21+
xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
22+
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
23+
gpu.return
24+
}
25+
}

0 commit comments

Comments
 (0)