
Conversation

@charithaintc
Contributor

The current transpose optimization in the xegpu-optimize-block-loads pass requires statically shaped memrefs. This PR adds support for all memref types.

@llvmbot
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Charitha Saumya (charithaintc)

Changes

The current transpose optimization in the xegpu-optimize-block-loads pass requires statically shaped memrefs. This PR adds support for all memref types.
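
For illustration, here is a minimal form of the input IR this change now covers, excerpted from the new test added below (the #b layout is the one defined in that test). Previously the rewrite was skipped because the source memref has a dynamic shape:

  #b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
  %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1]
      : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>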


Full diff: https://github.com/llvm/llvm-project/pull/170218.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp (+2-2)
  • (renamed) mlir/test/Dialect/XeGPU/optimize-block-loads.mlir (+29)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
index 4dc5ea4f7bb24..1642e9829a79d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -282,9 +282,9 @@ class XeGPUCreateNdDescOpPattern final
                        modifiedStrides[modifiedStrides.size() - 2]),
         innerLaneData);
 
-    // If the source is a static memref, we need to extract the pointer to
+    // If the source is a memref, we need to extract the pointer to
     // base address.
-    if (memrefType && memrefType.hasStaticShape()) {
+    if (memrefType) {
       auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create(
           rewriter, createNdOp.getLoc(), source);
       source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(),
diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
similarity index 91%
rename from mlir/test/Dialect/XeGPU/optimize-transpose.mlir
rename to mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
index 24a0de6ed48a5..6eaa82f42d02c 100644
--- a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir
+++ b/mlir/test/Dialect/XeGPU/optimize-block-loads.mlir
@@ -278,3 +278,32 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
   gpu.return
 }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @dynamic_memref(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> {
+// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK-NEXT:    %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?xf16> -> index
+// CHECK-NEXT:    %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64
+// CHECK-NEXT:    %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64
+// CHECK-SAME:      -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+// CHECK-NEXT:    %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]  {layout_result_0 =
+// CHECK-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x8xi32,
+// CHECK-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x8xi32>
+// CHECK-NEXT:    %{{.*}} = vector.bitcast %[[T2]] {layout_result_0 =
+// CHECK-SAME:      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xi32> to vector<16x16xf16>
+#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+gpu.module @xevm_module {
+gpu.func @dynamic_memref(%arg0: memref<?x?xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
+  %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16>
+  %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16>
+  %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  gpu.return %6 : vector<8x16xf32>
+}
+}

@github-actions

github-actions bot commented Dec 1, 2025

🐧 Linux x64 Test Results

  • 7171 tests passed
  • 595 tests skipped

✅ The build succeeded and all tests passed.

@akroviakov
Contributor

akroviakov commented Dec 2, 2025

I feel like extracting a pointer from a memref needs to resemble what xegpu-to-xevm has in its memref materialization.
Yes, right now they are the same, but AFAIK, @Jianhui-Li plans to expand it by accounting for metadata (ExtractStridedMetadataOp) so that memrefs coming from reinterpret casts also work correctly.

Perhaps the upcoming materialization logic can be moved to utils for reusability across passes.
Otherwise looks good.
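
For reference, a rough IR sketch (illustrative only, names hypothetical) of what a metadata-aware materialization along the lines suggested above might look like for a dynamically shaped source, recovering the dynamic offset, sizes, and strides alongside the base buffer instead of relying only on the aligned pointer:

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %src
      : memref<?x?xf16> -> memref<f16>, index, index, index, index, index
  %ptr = memref.extract_aligned_pointer_as_index %base : memref<f16> -> index
  // %offset, %sizes, and %strides could then feed the rewritten
  // xegpu.create_nd_tdesc, so views created via reinterpret_cast keep
  // their offset.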

@charithaintc
Contributor Author

I feel like extracting a pointer from a memref needs to resemble what xegpu-to-xevm has in its memref materialization. Yes, right now they are the same, but AFAIK, @Jianhui-Li plans to expand it by accounting for metadata (ExtractStridedMetadataOp) so that memrefs coming from reinterpret casts also work correctly.

Perhaps the upcoming materialization logic can be moved to utils for reusability across passes. Otherwise looks good.

Good point! In this PR, I transfer the existing strides and shape info to the new create_nd. For the offset, I guess extract_aligned_ptr should handle it? If not, we need to revisit this logic.

@akroviakov
Contributor

For the offset, I guess extract_aligned_ptr should handle it?

The lowering of ExtractAlignedPointerAsIndexOp extracts a value from the LLVM struct. The ReinterpretCastOp lowering sets values. I do not see the offset being applied in the lowering to LLVM; am I missing something?

@charithaintc
Contributor Author

For the offset, I guess extract_aligned_ptr should handle it?

The lowering of ExtractAlignedPointerAsIndexOp extracts a value from the LLVM struct. The ReinterpretCastOp lowering sets values. I do not see the offset being applied in the lowering to LLVM; am I missing something?

I see. In that case it would not work.
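
To make the concern concrete, a small illustrative IR sketch (hypothetical values): the aligned-pointer extraction returns the base of the underlying buffer, so an offset introduced by a reinterpret_cast is not folded into the extracted pointer and would have to be applied separately:

  %view = memref.reinterpret_cast %src to offset: [128], sizes: [64, 64], strides: [64, 1]
      : memref<?x?xf16> to memref<64x64xf16, strided<[64, 1], offset: 128>>
  %ptr = memref.extract_aligned_pointer_as_index %view
      : memref<64x64xf16, strided<[64, 1], offset: 128>> -> index
  // %ptr is the aligned base address of the buffer; the 128-element offset
  // lives only in the memref descriptor.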

gpu.func @dynamic_memref(%arg0: memref<?x?xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
Contributor


I don't think the strides/shape parameters make sense here for a dynamically shaped memref, since they need to match the strides/shape coming from the memref, which are usually unknown.
They are typically used when the source is not a memref but a pointer, as the lowest IR form.

Contributor Author

@charithaintc Dec 3, 2025


This is just a test case. In real use cases, these will be kernel arguments, meaning they are runtime values.

They are typically used when the source is not a memref but a pointer, as the lowest IR form.

Not really. This form of memref is really useful when we need to test with dynamic batch sizes that are changed from the host side. This PR is needed exactly for that, by the way: in my FA testing I use small batch sizes, but for perf runs I use large ones. Dynamic memrefs allow me to do this very easily.
I think you are referring to the unranked memref case (memref<*xf16>). Interestingly, the xegpu op definition does not allow unranked memrefs for some reason.

Contributor


You mean you are using memref<?x64xf16>?
I think the strides should come with the memref, not be provided by the user.

Contributor Author


You mean you are using memref<?x64xf16>?

No, the memref is fully dynamic. The user can send both shapes and strides as kernel args; we trust the user to send correct values.

I think the strides should come with the memref, not be provided by the user.

Why this restriction?

In any case, this discussion is irrelevant to this PR. This PR is all about supporting dynamic memrefs in the transpose optimization. The only thing to check here is whether we adjust the shape and strides to match the i32 dtype.
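
For a concrete picture of that adjustment, compare the create_nd_tdesc before and after the rewrite in the test from this PR (SSA constants written as literals for readability; %t0 stands for the index-cast base pointer; the divisor is the innerLaneData factor used by the pass, 2 here since two f16 elements pack into one i32):

  // before the rewrite (f16 view)
  %0 = xegpu.create_nd_tdesc %arg0, shape : [64, 64], strides : [64, 1]
      : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16, #b>
  // after the rewrite (i32 view of the same memory): inner size 64 -> 32,
  // leading stride 64 -> 32, and the base address becomes an i64 pointer
  %1 = xegpu.create_nd_tdesc %t0, shape : [64, 32], strides : [32, 1]
      : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>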

    assert(strides.size() >= 2 &&
           "Expected at least 2 strides for CreateNdDescOp");
    SmallVector<OpFoldResult> modifiedStrides(strides);
    modifiedStrides[modifiedStrides.size() - 2] = divideByConstant(
Contributor


https://github.com/llvm/llvm-project/pull/170384/files#r2586030379 is adding strides support. I am not sure the modification of the strides here is correct. Why is the stride divided by some constant here?
