[MLIR] Vector to XeGPU conversion: Use proper source variant for create_nd_tdesc op creation. #171216
Conversation
If the source strided memref is not fully static, that is, at least one of its shape, strides, or offset is kDynamic, use the i64 source variant for create_nd_tdesc op creation.
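For illustration, here is a minimal sketch of the IR this lowering now emits when the source is not fully static, mirroring the test expectations in this patch; the value names and the memref/tensor_desc types are illustrative only, not taken from the patch:

// %subview : memref<?x?xf32, strided<[?, 1], offset: ?>> (dynamic sizes, stride, offset).
%base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview
    : memref<?x?xf32, strided<[?, 1], offset: ?>>
    -> memref<f32>, index, index, index, index, index
// Recover the aligned base pointer as an index and cast it to i64.
%intptr = memref.extract_aligned_pointer_as_index %base : memref<f32> -> index
%ptr = arith.index_cast %intptr : index to i64
// i64 source variant: a leading 0 pads the 1-D metadata offset up to the source rank.
%desc = xegpu.create_nd_tdesc %ptr[0, %off], shape : [%sizes#0, %sizes#1],
    strides : [%strides#0, %strides#1] : i64 -> !xegpu.tensor_desc<8x16xf32>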
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir
Author: Sang Ik Lee (silee2)
Changes: If the source strided memref is not fully static, that is, at least one of its shape, strides, or offset is kDynamic, use the i64 source variant.
Patch is 20.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171216.diff
5 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 079e1e2a8ac67..b8606b261b781 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,48 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
xegpu::TensorDescType descType,
TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
+ assert(srcTy.isStrided() && "Expected strided memref type");
auto [strides, offset] = srcTy.getStridesAndOffset();
+ bool isStatic = true;
+
+ // Memref is dynamic if any of its shape, offset or strides is dynamic.
+ if (!srcTy.hasStaticShape()) {
+ isStatic = false;
+ }
+
+ if (offset == ShapedType::kDynamic)
+ isStatic = false;
+
+ for (auto stride : strides) {
+ if (stride == ShapedType::kDynamic) {
+ isStatic = false;
+ break;
+ }
+ }
xegpu::CreateNdDescOp ndDesc;
- if (srcTy.hasStaticShape()) {
+ if (isStatic) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
- // In case of any dynamic shapes, source's shape and strides have to be
+ // In case of a ranked dynamic memref, instead of passing the memref, the
+ // i64 base address and the source's offset, shape and strides have to be
// explicitly provided.
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- meta.getConstifiedMixedSizes(),
- meta.getConstifiedMixedStrides());
+ auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, meta.getBaseBuffer());
+ auto baseAddrI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), baseAddrIndex.getResult());
+ // Strided metadata only provides a 1D offset, but the create_nd_tdesc op
+ // expects the offset to match the rank of the source memref. Add leading
+ // zeros if rank > 1.
+ SmallVector<OpFoldResult> fullOffsets;
+ for (unsigned i = 0; i < srcTy.getRank() - 1; ++i) {
+ fullOffsets.push_back(rewriter.getI64IntegerAttr(0));
+ }
+ fullOffsets.push_back(meta.getConstifiedMixedOffset());
+ ndDesc = xegpu::CreateNdDescOp::create(
+ rewriter, loc, descType, baseAddrI64, fullOffsets,
+ meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
}
return ndDesc;
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index ae5141db16c09..867d1f20fb707 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,9 +10,13 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
@@ -30,9 +34,12 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
@@ -49,8 +56,11 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 1a10d917623cc..09bd571951a6b 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,9 +12,13 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
@@ -32,9 +36,12 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
@@ -51,8 +58,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 8bb272b1fe5fc..af330dced143e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -49,9 +49,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[COLLAPSED]]
-// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND-SAME: : memref<f32> -> index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// LOAD-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -148,8 +151,12 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// LOAD-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -185,7 +192,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [8, 16], strides : [16, 1] :
+// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
@@ -460,10 +471,11 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]]
-// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
-// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [256, 256], strides : [4096, 1] :
+// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
// LOAD-ND: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 43a1a7206e2cc..6185f8537d8e0 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,9 +16,13 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME: : memref<f32> -> index
+// STORE-ND: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][%[[OFFSET1]]], shape : [32],
+// STORE-ND-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
@@ -51,9 +55,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME: : memref<f32> -> index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFF1]]], shape : [16, 32],
+// STORE-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
@@ -87,8 +94,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// STORE-ND: %[[I64PTR:.+]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]][0, %[[OFFSET]]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// STORE-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -295,10 +305,11 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// STORE-ND-SAME: boundary_check = false
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+// STORE-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %0[%[[OFFSET]]], shape : [256], strides : [1] : i64 ->
+// STORE-ND-SAME: !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_...
[truncated]
charithaintc left a comment:
LGTM.
    auto baseAddrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), baseAddrIndex.getResult());
    // Strided metadata only provides 1D offset but create_nd_desc op expect
    // offset match the rank of source memref. Add leading zeros if rank > 1.
The 1d offset needs to be added to the baseAddrI64 here.
Updated the PR to generate an adjusted base address: base addr + offset * element_size_in_bytes.
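A rough sketch of what that adjustment could look like in IR; the f32 element size (4 bytes) and all value names are assumptions for illustration, not taken from the updated patch:

// Assumed: %ptr is the i64 base pointer, %off the index offset from the strided metadata.
%elem_bytes = arith.constant 4 : i64   // element size of f32 in bytes (assumption)
%off_i64 = arith.index_cast %off : index to i64
%byte_off = arith.muli %off_i64, %elem_bytes : i64
%adjusted_base = arith.addi %ptr, %byte_off : i64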
    bool isStatic = true;

    // Memref is dynamic if any of its shape, offset or strides is dynamic.
    if (!srcTy.hasStaticShape()) {
nit: braces can be skipped
Removed braces.
… runtime. SPIR-V kernel lowering no longer supports i1 store. Fuse GPU kernels to remove usage of i1 stores.
…evelZero runtime." This reverts commit f43d5c8.
If the source strided memref is not fully static, that is, at least one of its shape, strides, or offset is kDynamic, use the i64 source variant.
With this change, xegpu.create_nd_tdesc ops created by lowering from the vector dialect can rely on getMixedOffsets, getMixedSizes, and getMixedStrides to retrieve the relevant values.