Merged
49 changes: 24 additions & 25 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
BasePtr = 0, // Base pointer (i64)
BaseShapeW = 2, // Base shape width (i32)
BaseShapeH = 3, // Base shape height (i32)
TensorOffsetW = 4, // Tensor offset W (i32)
TensorOffsetH = 5 // Tensor offset H (i32)
BasePtr = 0, // Base pointer (i64)
BaseShapeW = 2, // Base shape width (i32)
BaseShapeH = 3, // Base shape height (i32)
BasePitch = 4, // Base pitch (i32)
};

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
Value offsetW;
Value offsetH;

// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
Contributor
@Jianhui-Li Jianhui-Li Dec 3, 2025

We should differentiate between the source being a memref or a pointer. For a pointer, the user is expected to provide both shapes and strides, so the above code works fine.
But for a memref source, the user may not know the strides, so the code should extract them from the memref. For a dynamic-shape memref this will trigger the ExtractStridedMetadataOp again (after the one in the type conversion that gets the base address and offset), but I expect it to be removed at the LLVM level.

Contributor Author

Even if a ranked dynamic memref triggers ExtractStridedMetadataOp, lowering will clean it up and allow direct access to the relevant fields of the lowered, decomposed memref.
See

(python-3.9) jovyan@jupyter-silee2:~/Projects/llvm-project [main|⚑ 29]$ cat strided.mlir
module {
  func.func @test(%arg0: memref<?x?xf32>) -> (index) {
    %base, %offset, %sizes:2, %strides:2 =
      memref.extract_strided_metadata %arg0 : memref<?x?xf32>
        -> memref<f32>, index, index, index, index, index
        return %strides#0 : index
  }
}
(python-3.9) jovyan@jupyter-silee2:~/Projects/llvm-project [main|⚑ 29]$ ./build/bin/mlir-opt --convert-to-llvm -canonicalize strided.mlir
module {
  llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> i64 {
    llvm.return %arg5 : i64
  }
}

You can see that the stride is forwarded directly from the kernel argument, which is lowered and unpacked from the memref.

Contributor

If I am not mistaken, for dynamic memrefs this still gets the strides from CreateNd's own parameters, not from the memref, right?

Is that done in a separate PR?

@Jianhui-Li It seems like this PR also still gets the strides from CreateNd. In that case, is it fine to move ahead with #170218? We can then fix the whole thing (take the strides using ExtractStridedMetadataOp) and remove shape and strides from CreateNd in a new PR?

Contributor Author

Yes. This PR does not address the issue of dynamic memrefs; they will be handled in a separate PR.

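As a rough sketch of the follow-up discussed above (not part of this PR), the pattern could take the pitch from a memref source via memref::ExtractStridedMetadataOp instead of the op's explicit strides. The placement, the 2D row-major assumption, and the local names below are illustrative assumptions; only `source`, `payloadElemTy`, and `getValueOrCreateCastToIndexLike` come from the surrounding pattern.

// Hypothetical sketch (assumption, not this PR's code): derive the pitch from
// a 2D memref source instead of the op's explicit strides.
if (auto memrefTy = dyn_cast<MemRefType>(source.getType())) {
  auto meta =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
  // Row-major 2D memref: the stride of dim 0 is the row pitch in elements.
  Value strideH = meta.getStrides()[0];
  Value basePitch =
      getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, strideH);
  (void)basePitch; // would replace createOffset(mixedStrides, 0) below
}
// For static shapes the metadata op folds to constants; for dynamic shapes the
// LLVM lowering forwards the stride from the unpacked memref descriptor, as
// the mlir-opt output above shows.
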
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
auto sourceTy = source.getType();
Expand Down Expand Up @@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
// Offsets are not supported (0 is used).
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
// Get pitch value from op fold results.
Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
payload = vector::InsertOp::create(
rewriter, loc, offsetW, payload,
static_cast<int>(NdTdescOffset::TensorOffsetW));
payload = vector::InsertOp::create(
rewriter, loc, offsetH, payload,
static_cast<int>(NdTdescOffset::TensorOffsetH));
payload =
vector::InsertOp::create(rewriter, loc, basePitch, payload,
static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
Value basePitch = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
// Offsets are provided by the op.
// convert them to i32.
Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute width in bytes.
Value surfaceW =
Value baseWidthByte =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
// Compute pitch in bytes.
Value basePitchByte =
arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);

// Get tile width from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(tileRank - 1);
@@ -331,18 +330,18 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
: rewriter.getIntegerType(elemBitSize));

Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
transpose, vnni,
rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
vblocks, transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
37 changes: 16 additions & 21 deletions mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
// CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
// CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
// CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
// CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>

@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
// CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
%size_x = arith.constant 64 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
// CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
// CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
// CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
%dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
gpu.return
}
60 changes: 7 additions & 53 deletions mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,78 +1,32 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s

gpu.module @load_store_check {
// CHECK-LABEL: gpu.func @load_store(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
// CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
// CHECK: %[[H:.*]] = arith.constant 8 : i32
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>

// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
// CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
// CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
// CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
// CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
// CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
// CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
// CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
// CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
// CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>


//CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
//CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
//CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
//CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
//CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
//CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
//CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
//CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
//CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
//CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
//CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
//CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
//CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
%loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
//CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>

%tid_x = gpu.thread_id x
%tid_x_i32 = arith.index_cast %tid_x : index to i32
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
//CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>

// CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
// CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
// CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
// CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
// CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
// CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>

//CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
//CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
//CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
//CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
//CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
//CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
//CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
//CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
//CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
//CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
//CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
//CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
//CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
//CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
//CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
//CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>