39 changes: 34 additions & 5 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,47 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
xegpu::TensorDescType descType,
TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
assert(srcTy.isStrided() && "Expected strided memref type");
auto [strides, offset] = srcTy.getStridesAndOffset();
bool isStatic = true;

// The memref is dynamic if any of its shape, offset, or strides is dynamic.
if (!srcTy.hasStaticShape()) {
Contributor:
nit: braces can be skipped

Contributor Author:
Removed braces.

isStatic = false;
}

if (offset == ShapedType::kDynamic)
isStatic = false;

for (auto stride : strides) {
if (stride == ShapedType::kDynamic) {
isStatic = false;
break;
}
}

xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
if (isStatic) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// In case of a ranked dynamic memref, instead of passing the memref itself,
// the i64 base address and the source's offset, shape, and strides have to
// be provided explicitly.
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
meta.getConstifiedMixedSizes(),
meta.getConstifiedMixedStrides());
auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, meta.getBaseBuffer());
auto baseAddrI64 = arith::IndexCastOp::create(
rewriter, loc, rewriter.getI64Type(), baseAddrIndex.getResult());
// Strided metadata only provides a 1D offset, but the create_nd_desc op
// expects the offsets to match the rank of the source memref. Add leading
// zeros if rank > 1.
Contributor:
The 1D offset needs to be added to the baseAddrI64 here.

Contributor Author:
Updated PR to generate adjusted base addr: base addr + offset * element_size_in_bytes.
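
For illustration only, a minimal sketch (not part of this diff) of how the adjusted base address described above could be computed, assuming the element width is a whole number of bytes; the helper names (offsetIdx, elemBytes, adjustedBase) are hypothetical and not taken from this patch:

// Hypothetical sketch: fold the strided-metadata offset into the i64 base
// address, i.e. adjusted base = base addr + offset * element_size_in_bytes.
Value offsetIdx = meta.getOffset();
Value elemBytes = arith::ConstantIndexOp::create(
    rewriter, loc, srcTy.getElementTypeBitWidth() / 8);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offsetIdx, elemBytes);
Value byteOffsetI64 = arith::IndexCastOp::create(
    rewriter, loc, rewriter.getI64Type(), byteOffset);
Value adjustedBase = arith::AddIOp::create(
    rewriter, loc, baseAddrI64.getResult(), byteOffsetI64);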

SmallVector<OpFoldResult> fullOffsets;
for (unsigned i = 0; i < srcTy.getRank() - 1; ++i) {
fullOffsets.push_back(rewriter.getI64IntegerAttr(0));
}
fullOffsets.push_back(meta.getConstifiedMixedOffset());
ndDesc = xegpu::CreateNdDescOp::create(
rewriter, loc, descType, baseAddrI64, fullOffsets,
meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
}

return ndDesc;
26 changes: 18 additions & 8 deletions mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,9 +10,13 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
@@ -30,9 +34,12 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

@@ -49,8 +56,11 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

26 changes: 18 additions & 8 deletions mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,9 +12,13 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>

@@ -32,9 +36,12 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[COLLAPSED]]
// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// CHECK-SAME: : memref<f32> -> index
// CHECK: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

// -----
@@ -51,8 +58,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>

// -----
32 changes: 22 additions & 10 deletions mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -49,9 +49,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// LOAD-ND-SAME: %[[COLLAPSED]]
// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND-SAME: : memref<f32> -> index
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
// LOAD-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -148,8 +151,12 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// LOAD-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]

@@ -185,7 +192,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [8, 16], strides : [16, 1] :
// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>

@@ -460,10 +471,11 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// LOAD-ND-SAME: %[[SUBVIEW]]
// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [256, 256], strides : [4096, 1] :
// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
// LOAD-ND: return %[[VEC]]

35 changes: 23 additions & 12 deletions mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,9 +16,13 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// STORE-ND-SAME: %[[COLLAPSED]]
// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// STORE-ND: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// STORE-ND-SAME: : memref<f32> -> index
// STORE-ND: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
// STORE-ND-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>

@@ -51,9 +55,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// STORE-ND-SAME: %[[COLLAPSED]]
// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// STORE-ND-SAME: : memref<f32> -> index
// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
// STORE-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>

@@ -87,8 +94,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
// STORE-ND: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
// STORE-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>

// STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -295,10 +305,11 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
// STORE-ND-SAME: %[[COLLAPSED]]
// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
// STORE-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET]]], shape : [256], strides : [1] : i64 ->
// STORE-ND-SAME: !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>

// STORE-SCATTER-LABEL: @store_to_subview(