[MLIR][XeGPU][XeVM] create_nd_tdesc: use correct pitch from strides. #170384
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes: Base memory pitch should be derived from base stride, not base width.

Patch is 22.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170384.diff

4 Files Affected:
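To illustrate why the pitch must come from the strides rather than the width, here is a minimal sketch (values and SSA names are made up for the example, not taken from the patch): the described region is 16 rows by 32 columns, but consecutive rows of the underlying buffer are 64 elements apart, so the pitch is the row stride (64), not the width (32).

// Hypothetical values for illustration only; %ptr is assumed to be a ui64 base pointer.
%c1  = arith.constant 1  : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
// shape is [H, W] = [16, 32]; strides are [64, 1], so the row pitch is 64 elements.
%tdesc = xegpu.create_nd_tdesc %ptr, shape: [%c16, %c32], strides: [%c64, %c1]
    : ui64 -> !xegpu.tensor_desc<8x16xf32>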
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..9c99a24bea8cd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
- BasePtr = 0, // Base pointer (i64)
- BaseShapeW = 2, // Base shape width (i32)
- BaseShapeH = 3, // Base shape height (i32)
- TensorOffsetW = 4, // Tensor offset W (i32)
- TensorOffsetH = 5 // Tensor offset H (i32)
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ BasePitch = 4, // Base pitch (i32)
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
- Value offsetW;
- Value offsetH;
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
auto sourceTy = source.getType();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets are not supported (0 is used).
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
+ // Get pitch value from op fold results.
+ Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetW, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetW));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetH, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload =
+ vector::InsertOp::create(rewriter, loc, basePitch, payload,
+ static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ Value basePitch = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
// Offsets are provided by the op.
// convert them to i32.
Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute width in bytes.
- Value surfaceW =
+ Value baseWidthByte =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+ // Compute pitch in bytes.
+ Value basePitchByte =
+ arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// Get tile width from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(tileRank - 1);
@@ -331,8 +330,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
+ rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
@@ -340,9 +339,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
- offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
: rewriter.getIntegerType(elemBitSize));
Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
+ rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..9a1e2cb3c7de0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
- // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+ // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
- // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
// CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
// CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+ // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
- // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+ // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+ // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
%size_x = arith.constant 64 : index
// CHECK: %[[C16:.*]] = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index
- // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
- // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
- // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
- // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
- // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
- // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
- // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
- // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
- // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+ // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+ // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+ // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+ // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+ // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+ // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+ // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
%dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
gpu.return
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..4c73c9c238b6e 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,78 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
gpu.module @load_store_check {
// CHECK-LABEL: gpu.func @load_store(
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+ // CHECK: %[[H:.*]] = arith.constant 8 : i32
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
- // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
- // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
- // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
- // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
- // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-
- //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
- //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
- //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
- //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
- //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+ //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
//CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
%loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
- //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
%tid_x = gpu.thread_id x
%tid_x_i32 = arith.index_cast %tid_x : index to i32
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
- //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
- // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
- // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
- // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
- // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
- // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
- // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
- // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
- //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
- //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
- //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
- //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
- //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
- //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
- //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
- //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
- //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
- //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
- //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
- //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
- //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
- //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
- //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+ //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
diff --git a/mlir/test/Conve...
[truncated]
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
We should differentiate between the source being a memref or a pointer. For a pointer source, the user is expected to provide both shape and strides, so the above code works fine.
But for a memref source, the user may not know the strides; the code should extract the strides from the memref. For a dynamic-shape memref, this will trigger the ExtractStridedMetadataOp again (after the one in the type conversion that gets the base address and offset), but I guess it should be removed at the LLVM level.
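For reference, a rough sketch of the two source forms discussed here (SSA names such as %ptr, %src, %dyn, %h, %w, %s0, %s1 are placeholders, not from the patch):

// Pointer source: shape and strides must be provided on the op.
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape: [%h, %w], strides: [%s0, %s1]
    : ui64 -> !xegpu.tensor_desc<8x16xf32>

// Static memref source: shape and strides are implied by the memref type.
%src_tdesc = xegpu.create_nd_tdesc %src : memref<16x32xf32>
    -> !xegpu.tensor_desc<8x16xf32>

// For a dynamic memref, the lowering could recover the strides itself, e.g.:
%base, %offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %dyn : memref<?x?xf16>
    -> memref<f16>, index, index, index, index, index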
Even if a ranked dynamic memref triggers ExtractStridedMetadataOp, lowering will clean it up and allow direct access to the relevant fields of the lowered, decomposed memref.
See:
$ cat strided.mlir
module {
func.func @test(%arg0: memref<?x?xf32>) -> (index) {
%base, %offset, %sizes:2, %strides:2 =
memref.extract_strided_metadata %arg0 : memref<?x?xf32>
-> memref<f32>, index, index, index, index, index
return %strides#0 : index
}
}
$ ./build/bin/mlir-opt --convert-to-llvm -canonicalize strided.mlir
module {
llvm.func @test(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> i64 {
llvm.return %arg5 : i64
}
}
You can see that the stride is forwarded directly from a kernel argument, which is lowered and unpacked from the memref.
If I am not mistaken, for dynamic memrefs this still gets the strides from create_nd_tdesc's own parameters, not from the memref, right?
Is that done in a separate PR?
@Jianhui-Li Seems like this PR also still gets the strides from create_nd_tdesc? So in that case, is it fine to move ahead with #170218? And we can fix the whole thing (take the strides via ExtractStridedMetadataOp) and remove shape and strides from create_nd_tdesc in a new PR?
Yes. This PR does not address the issue of dynamic memrefs; they will be handled in a separate PR.
Jianhui-Li left a comment:
LGTM
Base memory pitch should be derived from base stride, not base width.
Remove offset fields from tensor descriptor payload and add pitch field.
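For reference, a rough sketch of the payload built for a static memref source after this change, mirroring the updated test above (SSA names such as %base_addr, %shape_w, %shape_h, %pitch are illustrative and assumed to be defined earlier):

// 8xi32 payload layout after this change:
//   [0..1] base pointer (i64), [2] base shape W, [3] base shape H, [4] base pitch.
%payload = arith.constant dense<0> : vector<8xi32>
%as_i64  = vector.bitcast %payload : vector<8xi32> to vector<4xi64>
%p0      = vector.insert %base_addr, %as_i64 [0] : i64 into vector<4xi64>
%as_i32  = vector.bitcast %p0 : vector<4xi64> to vector<8xi32>
%p1      = vector.insert %shape_w, %as_i32 [2] : i32 into vector<8xi32>
%p2      = vector.insert %shape_h, %p1 [3] : i32 into vector<8xi32>
%p3      = vector.insert %pitch,   %p2 [4] : i32 into vector<8xi32>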