
Commit 447af32

[MLIR][XeGPU][XeVM] create_nd_tdesc: use correct pitch from strides. (#170384)
The base memory pitch should be derived from the base stride, not the base width. Remove the offset fields (which were hardcoded to 0) from the tensor descriptor payload and add a pitch field.
1 parent b8ddbc4
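The distinction matters whenever the descriptor's source view is not contiguous: the pitch is the row-to-row distance in memory, and it only coincides with the base width for a fully packed 2D buffer. A minimal C++ sketch of the addressing rule involved (illustrative only, not code from this patch; the function and parameter names are made up):

```cpp
#include <cstdint>

// Byte address of element (row, col) in a 2D surface. Rows advance by the
// pitch (the stride of the outer dimension), not by the visible width.
uint64_t elemAddr(uint64_t base, uint32_t row, uint32_t col,
                  uint32_t pitchElems, uint32_t elemBytes) {
  return base + uint64_t(row) * pitchElems * elemBytes +
         uint64_t(col) * elemBytes;
}
// Example: a 16-element-wide view into a 32-element-wide buffer has
// width 16 but pitch 32; deriving the pitch from the width would make
// every row after the first read from the wrong place.
```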

4 files changed (+52 -115 lines)

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 24 additions & 25 deletions
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
 
 // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
 enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
+  BasePtr = 0,    // Base pointer (i64)
+  BaseShapeW = 2, // Base shape width (i32)
+  BaseShapeH = 3, // Base shape height (i32)
+  BasePitch = 4,  // Base pitch (i32)
 };
 
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
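For reference, a sketch of how the revised payload packs into eight i32 lanes, as implied by the enum above; this hypothetical struct is an illustration only (the pass builds the payload with vector.bitcast/vector.insert, as the hunks below show, and the lane view of the i64 assumes little-endian layout):

```cpp
#include <cstdint>
#include <cstring>

// Sketch of the revised 8xi32 nd tensor descriptor payload: the i64 base
// pointer occupies lanes 0-1 (written through the 4xi64 view), the base
// shape occupies lanes 2-3, and the new pitch field occupies lane 4.
struct NdTdescPayload {
  int32_t lanes[8] = {};
  void setBasePtr(uint64_t p) { std::memcpy(&lanes[0], &p, sizeof(p)); }
  void setShape(int32_t w, int32_t h) { lanes[2] = w; lanes[3] = h; }
  void setPitch(int32_t pitch) { lanes[4] = pitch; }
};
```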
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
     Value baseAddr;
     Value baseShapeW;
     Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     auto sourceTy = source.getType();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
+    // Get pitch value from op fold results.
+    Value basePitch = createOffset(mixedStrides, 0);
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
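Note: for a row-major 2D view, mixedStrides[0] is the outer (row) stride in elements, which is exactly the pitch. For example, memref<16x32xf32> has sizes [16, 32] and strides [32, 1], so width and pitch are both 32, while a narrower subview of the same allocation would shrink the width but keep pitch 32.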
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                  static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload =
+        vector::InsertOp::create(rewriter, loc, basePitch, payload,
+                                 static_cast<int>(NdTdescOffset::BasePitch));
     rewriter.replaceOp(op, payload);
     return success();
   }
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+    Value basePitch = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
     // Offsets are provided by the op.
     // convert them to i32.
     Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
     // Compute width in bytes.
-    Value surfaceW =
+    Value baseWidthByte =
         arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+    // Compute pitch in bytes.
+    Value basePitchByte =
+        arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
 
     // Get tile width from the tensor descriptor type.
     auto tileW = tdescTy.getDimSize(tileRank - 1);
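Worked example from the loadstore_nd.mlir test below: the source is a contiguous memref<8x16xf32>, so baseWidthByte = 16 * 4 = 64 and basePitchByte = 16 * 4 = 64 as well; after canonicalization both operands become the same i32 constant 64.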
@@ -331,18 +330,18 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       auto storeCacheControl =
           translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
+          rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+          basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
           xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
       rewriter.eraseOp(op);
     } else {
       auto loadCacheControl =
           translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
         xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
-            offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         rewriter.eraseOp(op);
       } else {
         VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                 : rewriter.getIntegerType(elemBitSize));
 
         Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
+            rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
+            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+            vblocks, transpose, vnni,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
         resultFlatVec = vector::BitCastOp::create(
             rewriter, loc,
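To make the new operand order concrete: each 2D block op now receives the base width in bytes, the base height, and the base pitch in bytes, followed by the x/y offsets. A scalar model of the load's semantics under that contract (an illustration only, not the xevm.blockload2d lowering; the function name and the absence of bounds handling are assumptions):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Copy a tileH x tileW block starting at (offsetW, offsetH) out of a
// surface whose rows are pitchElems elements apart. Note the row step
// uses the pitch, not the tile or surface width.
std::vector<float> blockLoad2d(const float *base, uint32_t pitchElems,
                               uint32_t offsetW, uint32_t offsetH,
                               uint32_t tileW, uint32_t tileH) {
  std::vector<float> out;
  out.reserve(std::size_t(tileW) * tileH);
  for (uint32_t r = 0; r < tileH; ++r)
    for (uint32_t c = 0; c < tileW; ++c)
      out.push_back(base[std::size_t(offsetH + r) * pitchElems +
                         (offsetW + c)]);
  return out;
}
```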

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 16 additions & 21 deletions
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
       %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
     // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+    // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
     // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
-    // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
     // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
     // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+    // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
     // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
       : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
     // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
-    // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
     // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
     // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
     // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+    // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+    // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
     // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
     // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
     // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
     %size_x = arith.constant 64 : index
     // CHECK: %[[C16:.*]] = arith.constant 16 : index
     %BLOCK_DMODEL = arith.constant 16 : index
-    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
-    // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
-    // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
-    // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-    // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+    // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+    // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+    // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
     %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
     gpu.return
   }
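In the dynamic-shape case above, the pitch comes from the leading stride operand: with strides [%BLOCK_DMODEL, %c1], PITCH3 is the index_cast of the same %[[C16]] value that also supplies the width, and only payload lanes 0-4 are populated; lane 5 is no longer written.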

mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 7 additions & 53 deletions
@@ -1,78 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
   // CHECK-LABEL: gpu.func @load_store(
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
   gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+    // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+    // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+    // CHECK: %[[H:.*]] = arith.constant 8 : i32
     %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
     %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
-    // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-    // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
-    // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
-    // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
-    // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
-
-    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-    //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-    //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
    //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
    //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
    //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
    %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
      : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-    //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
     %tid_x = gpu.thread_id x
     %tid_x_i32 = arith.index_cast %tid_x : index to i32
     %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-    //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
     %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-    // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
     %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
-    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-    //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-    //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-    //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-    //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-    //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+    //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+    //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
    //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
    xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
      : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
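The RUN line now also runs -canonicalize, which folds the payload insert/extract chains away; because the 8x16xf32 memrefs here are contiguous, the width and pitch both fold to a single shared constant %[[W_P_BYTES]] (16 elements * 4 bytes = 64), which is why it appears twice in the blockload2d and blockstore2d operand lists.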
