
Commit c7a59a9

Address reviewer comments.
1 parent b2c9ffa commit c7a59a9

5 files changed (+75, -90 lines)


mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 56 additions & 50 deletions
@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    // Descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
      }
      baseAddr =
          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
    } else {
      baseAddr = adaptor.getSource();
+      if (baseAddr.getType() != i64Ty) {
+        // Pointer type may be i32. Cast to i64 if needed.
+        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      }
+    }
+    // 1D tensor descriptor is just the base address.
+    if (rank == 1) {
+      rewriter.replaceOp(op, baseAddr);
+      return success();
    }
    // Utility for creating offset values from op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
    // Get shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
-      // Pointer type may be i32. Cast to i64 if needed.
-      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
    // Populate payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -257,57 +258,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
+    auto tileRank = tdescTy.getRank();
+    if (opOffsetsSize != tileRank)
+      return rewriter.notifyMatchFailure(
+          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    if (elemBitSize % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");

-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets are provided by the op.
-    // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetH);
    // Get address space from tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Compute element byte size.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    auto tileRank = tdescTy.getRank();
-    // Get tile width from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(tileRank - 1);
    if (tileRank == 2) {
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+      Value payLoadAsI64 =
+          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+      Value basePtr =
+          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                    static_cast<int>(NdTdescOffset::BasePtr));
+      Value baseShapeW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+      Value baseShapeH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      // Offsets are provided by the op.
+      // convert them to i32.
+      Value offsetW =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      Value offsetH =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
      // Convert base pointer (i64) to LLVM pointer type.
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // Compute width in bytes.
-      Value elemByteSize = arith::ConstantIntOp::create(
-          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      Value surfaceW =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

+      // Get tile width from the tensor descriptor type.
+      auto tileW = tdescTy.getDimSize(tileRank - 1);
      // Get tile height from the tensor descriptor type.
      auto tileH = tdescTy.getDimSize(0);
      // Get vblocks from the tensor descriptor type.
@@ -367,21 +368,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
        }
      }
    } else {
-      // Get address from base address and offsets.
+      // 1D tensor descriptor.
+      // `tdesc` represents base address as i64
      // Offset in number of elements, need to multiply by element byte size.
-      // Compute linear offset.
-      // linearOffset = offsetH * baseShapeW + offsetW
-      Value offsetHInElems =
-          rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
-      Value linearOffset =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
-      // Then compute byte offset by multiplying with element byte size.
-      // byteOffset = linearOffset * elemByteSize
+      // Compute byte offset.
+      // byteOffset = offset * elementByteSize
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI64Type(), offset);
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
-          rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
+          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      // Final address = basePtr + byteOffset
      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
-          loc, basePtr,
+          loc, tdesc,
          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                          byteOffset));
      // Convert base pointer (i64) to LLVM pointer type.
@@ -992,7 +995,10 @@ struct ConvertXeGPUToXeVMPass
      return VectorType::get(sum, elemType);
    });
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      // Scattered descriptors are not supported in XeVM lowering.
      if (type.isScattered())
+        return {};
+      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
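For reference, the net effect of the rank-1 changes above reduces to plain address arithmetic: the rank-1 "descriptor" is just the 64-bit base address, and the final address is base + offset * element byte size. Below is a minimal standalone C++ sketch of that computation; it is illustrative only, not code from the pass, and the helper name finalAddress1D is hypothetical.

#include <cassert>
#include <cstdint>

// Illustrative sketch of the rank-1 lowering's address math (not pass code).
// byteOffset = offset * elemByteSize; finalAddr = baseAddr + byteOffset.
static uint64_t finalAddress1D(uint64_t baseAddr, uint64_t offsetInElems,
                               unsigned elemBitSize) {
  assert(elemBitSize % 8 == 0 && "element bit width must be a multiple of 8");
  const uint64_t elemByteSize = elemBitSize / 8;
  return baseAddr + offsetInElems * elemByteSize;
}

int main() {
  // f32 elements (32-bit): an offset of 96 elements lands 384 bytes past the base.
  return finalAddress1D(/*baseAddr=*/0, /*offsetInElems=*/96,
                        /*elemBitSize=*/32) == 384
             ? 0
             : 1;
}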

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 2 additions & 2 deletions
@@ -29,13 +29,13 @@ gpu.module @create_nd_tdesc {

    // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
    // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
    // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
    // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
    // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
-    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
    // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
    // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
    // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
@@ -53,11 +53,11 @@ gpu.module @create_nd_tdesc {
    %BLOCK_DMODEL = arith.constant 16 : index
    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
    // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
    // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
    // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
    // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
    // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
    // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
    // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>

mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir

Lines changed: 14 additions & 35 deletions
@@ -2,55 +2,34 @@

gpu.module @load_store_check {
  // CHECK-LABEL: @load_store(
-  // CHECK-SAME: %[[SRC:.*]]: memref<8x64xf32, 1>, %[[DST:.*]]: memref<8x32xf32, 1>
-  gpu.func @load_store(%src: memref<8x64xf32, 1>, %dst: memref<8x32xf32, 1>) kernel {
+  // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1>
+  gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel {
    // CHECK: %[[C512:.*]] = arith.constant 512 : i64
-    // CHECK: %[[C32:.*]] = arith.constant 32 : i32
    // CHECK: %[[C384:.*]] = arith.constant 384 : i64
-    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
-    // CHECK: %[[C8:.*]] = arith.constant 8 : i32
-    // CHECK: %[[C64:.*]] = arith.constant 64 : i32
-    // CHECK: %[[C0:.*]] = arith.constant 0 : i32

-    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<8x64xf32, 1> to memref<8x64xf32>
-    %srcce = memref.memory_space_cast %src : memref<8x64xf32, 1> to memref<8x64xf32>
-    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<8x32xf32, 1> to memref<8x32xf32>
-    %dstte = memref.memory_space_cast %dst : memref<8x32xf32, 1> to memref<8x32xf32>
+    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
+    %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
+    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
+    %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>

-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<8x64xf32> -> index
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
    // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-    // CHECK: %[[VEC1:.*]] = vector.insert %[[INTPTR_I64]], %[[CST]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VEC2:.*]] = vector.bitcast %[[VEC1]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VEC3:.*]] = vector.insert %[[C64]], %[[VEC2]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VEC4:.*]] = vector.insert %[[C8]], %[[VEC3]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VEC5:.*]] = vector.insert %[[C0]], %[[VEC4]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VEC6:.*]] = vector.insert %[[C0]], %[[VEC5]] [5] : i32 into vector<8xi32>
-    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x64xf32> -> !xegpu.tensor_desc<32xf32>
-    // CHECK: %[[VEC7:.*]] = vector.bitcast %[[VEC6]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[EXTR:.*]] = vector.extract %[[VEC7]][0] : i64 from vector<4xi64>
-    // CHECK: %[[ADDR:.*]] = arith.addi %[[EXTR]], %[[C384]] : i64
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
+    // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
    // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
    // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}>
    // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
-    %loaded = xegpu.load_nd %src_tdesc[1, 32] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+    %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
      : !xegpu.tensor_desc<32xf32> -> vector<2xf32>

-    // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<8x32xf32> -> index
+    // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
    // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
-    // CHECK: %[[VEC1_1:.*]] = vector.insert %[[INTPTR1_I64]], %[[CST]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VEC2_1:.*]] = vector.bitcast %[[VEC1_1]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VEC3_1:.*]] = vector.insert %[[C32]], %[[VEC2_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VEC4_1:.*]] = vector.insert %[[C8]], %[[VEC3_1]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VEC5_1:.*]] = vector.insert %[[C0]], %[[VEC4_1]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VEC6_1:.*]] = vector.insert %[[C0]], %[[VEC5_1]] [5] : i32 into vector<8xi32>
-    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
-    // CHECK: %[[VEC7_1:.*]] = vector.bitcast %[[VEC6_1]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[EXTR1:.*]] = vector.extract %[[VEC7_1]][0] : i64 from vector<4xi64>
-    // CHECK: %[[ADDR1:.*]] = arith.addi %[[EXTR1]], %[[C512]] : i64
+    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+    // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
    // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
    // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}>
    // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
-    xegpu.store_nd %loaded, %dst_tdesc[4, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+    xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
      : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
    gpu.return
  }
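For readers tracing the constants in this test: with f32 elements (4 bytes each), the new 1D lowering's byteOffset = offset * elemByteSize rule gives 96 * 4 = 384 for the load at [96] (checked as %[[C384]]) and 128 * 4 = 512 for the store at [128] (checked as %[[C512]]), each added directly to the extracted base pointer.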

mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 2 additions & 2 deletions
@@ -16,6 +16,7 @@ gpu.module @load_store_check {
    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>


+    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
    //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
    //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
@@ -25,7 +26,6 @@ gpu.module @load_store_check {
    //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
    //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
    //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
    //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
    //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
@@ -52,6 +52,7 @@ gpu.module @load_store_check {
    // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>

+    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
    //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
    //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
    //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
@@ -61,7 +62,6 @@ gpu.module @load_store_check {
    //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
    //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
    //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
    //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
    //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
    //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
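The only change in this file is that the element-size constant (%[[LD_SIZEOF_F32]] / %[[SIZEOF_F32]]) is now materialized before the descriptor bitcast; the 2D payload layout consumed by the blockload2d/blockstore2d path is untouched. For readers following the extract indices, a hedged C++ sketch of that layout follows; the field positions are inferred from the vector.insert/extract indices in these tests and the NdTdescOffset names used in the pass, and the enum below is a reader's sketch, not the real declaration.

// Illustrative only: inferred payload layout of a 2D tensor descriptor
// (vector<8xi32>, also viewed as vector<4xi64> for the base pointer).
enum NdTdescFieldSketch : int {
  kBasePtr = 0,    // i64 lane 0 when the payload is viewed as vector<4xi64>
  kBaseShapeW = 2, // i32 lane 2 when viewed as vector<8xi32>
  kBaseShapeH = 3, // i32 lane 3
  // i32 lanes 4 and 5 hold the tile offsets, zero-initialized by create_nd_tdesc.
};

static_assert(kBaseShapeW == 2 && kBaseShapeH == 3,
              "matches the [2]/[3] extract indices in the tests above");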
