[MLIR][Conversion] XeGPU to XeVM: Lower ranked dynamic base memory for create_nd_tdesc. #164283
Conversation
create_nd_tdesc source memref just needs to be a ranked memref.
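For illustration, a minimal sketch of the kind of IR this change admits (value names and shapes here are hypothetical; the pattern mirrors the test added below). The source is a ranked memref with dynamic dimensions, and the shape and strides are supplied as operands:

// Hypothetical example: previously rejected with "Expected static memref
// shape.", now lowered using the explicit shape/stride operands.
%desc = xegpu.create_nd_tdesc %dyn, shape: [%h, %w], strides: [%w, %c1]
    : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>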
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes: create_nd_tdesc source memref just needs to be a ranked memref.

Full diff: https://github.com/llvm/llvm-project/pull/164283.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..33e8f2ed1f6ed 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -194,8 +194,8 @@ class CreateNdDescToXeVMPattern
// If source is a memref, we need to extract the aligned pointer as index.
// Pointer type is passed as i32 or i64 by type converter.
if (sourceMemrefTy) {
- if (!sourceMemrefTy.hasStaticShape()) {
- return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
+ if (!sourceMemrefTy.hasRank()) {
+ return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
}
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index d6e36fa73bf04..09ef76c9d1740 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -4,8 +4,9 @@ gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
// CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+ // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
- %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
+ %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -43,6 +44,28 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
+ %size_x = arith.constant 64 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ %BLOCK_DMODEL = arith.constant 16 : index
+ // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+ // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
+ // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
+ // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+ // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
+ // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+ %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
gpu.return
}
}
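As a side note, the CHECK lines suggest the descriptor payload layout below; this is inferred from the test above rather than stated by the pattern itself:

// Descriptor payload, vector<8xi32> (layout inferred from the CHECK lines):
//   lanes 0-1 : base address, one i64 inserted via bitcast to vector<4xi64>
//   lane  2   : shape W (innermost dimension), i32
//   lane  3   : shape H (outer dimension), i32
//   lane  4   : offset W, i32 (constant 0 in the dynamic-shape case here)
//   lane  5   : offset H, i32 (constant 0 in the dynamic-shape case here)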
Can you please update the PR title, briefly specifying what exactly is being relaxed?
Changed the title to reflect the actual change.
nbpatel left a comment:
LGTM
The current lowering pattern for create_nd_tdesc restricts the source memref to a static shape. For a ranked memref with dynamic dimensions, create_nd_tdesc already provides the shape (and strides) as operands, so the lowering can use those values instead of returning a match failure.
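Concretely, a sketch of the lowered IR for the dynamic case, condensed from the CHECK lines in the test (names are illustrative):

// The base address still comes from the memref itself...
%base = memref.extract_aligned_pointer_as_index %dyn : memref<?x?xf16> -> index
%addr = arith.index_castui %base : index to i64
// ...while the dimensions now come from the op's shape operands, not the type.
%w_i32 = arith.index_cast %w : index to i32
%h_i32 = arith.index_cast %h : index to i32
// Both are then inserted into the vector<8xi32> descriptor payload as usual.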