@silee2
Contributor

@silee2 silee2 commented Oct 20, 2025

The current lowering pattern for create_nd_tdesc restricts the source memref to a static shape.
For a ranked memref with dynamic dimensions, create_nd_tdesc already receives the shape as operands.
The lowering can use those values instead of returning a match failure.

The create_nd_tdesc source memref just needs to be a ranked memref.
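
To illustrate the idea in the description, here is a minimal standalone sketch of how sizes could be resolved when some dimensions are dynamic. It is a simplified model, not the MLIR API: `kDynamicDim` and `resolveSizes` are hypothetical names invented for this example, and the convention that dynamic dimensions consume the shape operands in order is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel marking a dimension whose size is only known at runtime.
constexpr int64_t kDynamicDim = std::numeric_limits<int64_t>::min();

// Returns the sizes a lowering could use: static dimensions come from the
// memref type, dynamic dimensions come from the shape operands, in order.
std::vector<int64_t> resolveSizes(const std::vector<int64_t> &typeShape,
                                  const std::vector<int64_t> &shapeOperands) {
  std::vector<int64_t> sizes;
  sizes.reserve(typeShape.size());
  std::size_t next = 0;
  for (int64_t dim : typeShape) {
    if (dim != kDynamicDim) {
      sizes.push_back(dim); // statically known size
      continue;
    }
    assert(next < shapeOperands.size() && "missing operand for dynamic dim");
    sizes.push_back(shapeOperands[next++]); // size supplied by the op
  }
  return sizes;
}
```

With a fully dynamic two-dimensional shape and operands 64 and 16, as in the test added by this PR, the resolved sizes are 64 and 16. A fully static shape needs no operands.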
@llvmbot
Member

llvmbot commented Oct 20, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

The create_nd_tdesc source memref just needs to be a ranked memref.


Full diff: https://github.com/llvm/llvm-project/pull/164283.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+2-2)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+24-1)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..33e8f2ed1f6ed 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern
     // If source is a memref, we need to extract the aligned pointer as index.
     // Pointer type is passed as i32 or i64 by type converter.
     if (sourceMemrefTy) {
-      if (!sourceMemrefTy.hasStaticShape()) {
-        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+      if (!sourceMemrefTy.hasRank()) {
+        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index d6e36fa73bf04..09ef76c9d1740 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -4,8 +4,9 @@ gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
   // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+  // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
-  %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
+  %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -43,6 +44,28 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
         // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        // CHECK: %[[C1:.*]] = arith.constant 1 : index
+        %c1 = arith.constant 1 : index
+        // CHECK: %[[C64:.*]] = arith.constant 64 : index
+        %size_x = arith.constant 64 : index
+        // CHECK: %[[C16:.*]] = arith.constant 16 : index
+        %BLOCK_DMODEL = arith.constant 16 : index
+        // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+        // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+        // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
+        // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
+        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+        %dyn_tdesc  = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
         gpu.return
     }
 }
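
The CHECK lines in the new test spell out the payload layout for the dynamic case: the base address occupies the first i64 lane of a zeroed vector<8xi32>, the inner dimension (16) goes to lane 2, the outer dimension (64) to lane 3, and the two zero offsets to lanes 4 and 5. The following is a rough host-side model of that packing, assuming nothing beyond what the CHECK lines show; `Payload` and `packNdDescPayload` are hypothetical names, and `memcpy` stands in for the vector.bitcast round trip.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side stand-in for the vector<8xi32> payload.
using Payload = std::array<uint32_t, 8>;

// Packs the descriptor fields in the lane order the CHECK lines show.
Payload packNdDescPayload(uint64_t baseAddr, int32_t width, int32_t height,
                          int32_t offsetW, int32_t offsetH) {
  assert(width >= 0 && height >= 0 && "shape dimensions must be non-negative");
  Payload p{}; // corresponds to arith.constant dense<0> : vector<8xi32>
  // Models bitcast to vector<4xi64>, insert at [0], bitcast back:
  // the first 8 bytes of the payload hold the 64-bit base address.
  std::memcpy(p.data(), &baseAddr, sizeof(baseAddr));
  p[2] = static_cast<uint32_t>(width);   // inner dimension
  p[3] = static_cast<uint32_t>(height);  // outer dimension
  p[4] = static_cast<uint32_t>(offsetW); // zero in the dynamic test case
  p[5] = static_cast<uint32_t>(offsetH); // zero in the dynamic test case
  return p;
}
```

Calling `packNdDescPayload(addr, 16, 64, 0, 0)` mirrors the `%dyn_tdesc` case: lanes 6 and 7 stay zero, and the shape values come from the `shape:` operands rather than from the memref type.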

@Garra1980

Can you please update the PR title to briefly specify what exactly is being relaxed?

@silee2 silee2 changed the title from "[MLIR][Conversion] XeGPU to XeVM: Relax create_nd_tdesc restriction." to "[MLIR][Conversion] XeGPU to XeVM: Lower ranked dynamic base memory for create_nd_tdesc." on Oct 23, 2025
@silee2
Contributor Author

silee2 commented Oct 23, 2025

Can you please update the PR title to briefly specify what exactly is being relaxed?

Changed title to reflect actual change.

Contributor

@nbpatel nbpatel left a comment

LGTM

@silee2 silee2 merged commit b258f5c into llvm:main Oct 29, 2025
10 checks passed
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
…r create_nd_tdesc. (llvm#164283)

DEBADRIBASAK pushed a commit to DEBADRIBASAK/llvm-project that referenced this pull request Nov 3, 2025
…r create_nd_tdesc. (llvm#164283)

5 participants