diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..cee051ab4dd7d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -48,15 +48,6 @@ namespace {
 static constexpr int32_t systolicDepth{8};
 static constexpr int32_t executionSize{16};
 
-// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
-enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
-};
-
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   switch (xeGpuMemspace) {
   case xegpu::MemorySpace::Global:
@@ -151,6 +142,22 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+// Compute the product of sizes in the range [lo, hi) from the sizes array.
+static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
+                               Location loc, ArrayRef<OpFoldResult> sizes,
+                               size_t lo, size_t hi) {
+  Value product =
+      arith::ConstantIntOp::create(rewriter, loc, rewriter.getI64Type(), 1);
+  for (size_t idx = lo; idx < hi; idx++) {
+    OpFoldResult ofr = sizes[idx];
+    Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
+    sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI64Type(), sizeVal);
+    product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
+  }
+  return product;
+}
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -162,86 +169,14 @@ class CreateNdDescToXeVMPattern
     if (mixedOffsets.size() != 0)
       return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
-    auto source = op.getSource();
-    // Op is lowered to a code sequence that populates payload.
-    // Payload is a 8xi32 vector. Offset to individual fields are defined in
-    // NdTdescOffset enum.
-    Type payloadElemTy = rewriter.getI32Type();
-    VectorType payloadTy = VectorType::get(8, payloadElemTy);
-    Type i64Ty = rewriter.getI64Type();
-    // 4xi64 view is used for inserting the base pointer.
-    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
-    // Initialize payload to zero.
-    Value payload = arith::ConstantOp::create(
-        rewriter, loc,
-        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
-
-    Value baseAddr;
-    Value baseShapeW;
-    Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
-    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
-    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    // Descriptor shape is expected to be 2D.
-    int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
-    auto sourceTy = source.getType();
-    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
-    // If source is a memref, we need to extract the aligned pointer as index.
-    // Pointer type is passed as i32 or i64 by type converter.
-    if (sourceMemrefTy) {
-      if (!sourceMemrefTy.hasStaticShape()) {
-        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
-      }
-      baseAddr =
-          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
-    } else {
-      baseAddr = adaptor.getSource();
-    }
-    // Utility for creating offset values from op fold result.
-    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
-                            unsigned idx) -> Value {
-      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
-      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
-      return val;
-    };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    // Get shape values from op fold results.
-    baseShapeW = createOffset(mixedSizes, 1);
-    baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
+    Value baseAddr = adaptor.getSource();
+    Type i64Ty = rewriter.getI64Type();
+    if (baseAddr.getType() != i64Ty) {
       // Pointer type may be i32. Cast to i64 if needed.
       baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
     }
-    // Populate payload.
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
-    payLoadAsI64 =
-        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
-                                 static_cast<int>(NdTdescOffset::BasePtr));
-    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
-    payload =
-        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
-                                 static_cast<int>(NdTdescOffset::BaseShapeW));
-    payload =
-        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
-                                 static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
-    rewriter.replaceOp(op, payload);
+    rewriter.replaceOp(op, baseAddr);
     return success();
   }
 };
@@ -255,14 +190,24 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto tdVal = op.getTensorDesc();
+    xegpu::CreateNdDescOp descOp =
+        tdVal.template getDefiningOp<xegpu::CreateNdDescOp>();
+    if (!descOp)
+      return rewriter.notifyMatchFailure(
+          op, "Expected tensor descriptor to be created by CreateNdDescOp.");
+    auto mixedStrides = descOp.getMixedStrides();
     auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+    auto mixedSizes = descOp.getMixedSizes();
+    size_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != mixedStrides.size())
+      return rewriter.notifyMatchFailure(
+          op, "Offsets size should match base memory rank.");
+    if (opOffsetsSize < 2)
+      return rewriter.notifyMatchFailure(op, "Expected at least 2D offset.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
-    auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
     if (tdescTy.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
@@ -272,23 +217,58 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
 
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+    Value basePtr = adaptor.getTensorDesc();
+    // Utility for creating offset values from op fold result.
+    Type payloadElemTy = rewriter.getIntegerType(32);
+    auto createOffset = [&](OpFoldResult ofr) -> Value {
+      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
+      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
+      return val;
+    };
+    auto srcRank = mixedSizes.size();
+    // Get shape values from op fold results.
+    Value baseShapeW = createOffset(mixedSizes[srcRank - 1]);
+    Value baseShapeH;
+    if (srcRank == 2) {
+      baseShapeH = createOffset(mixedSizes[0]);
+    } else {
+      // Generate compute chain for height (product of sizes of all but the last
+      // dimension).
+      baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
+      baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
+                                                   baseShapeH);
+    }
     // Offsets are provided by the op.
     // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    // Offset computation assumes base memory layout is row major.
+    Value offsetW = getValueOrCreateConstantIntOp(
+        rewriter, loc, mixedOffsets[opOffsetsSize - 1]);
     offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    Value offsetH;
+    if (opOffsetsSize == 2)
+      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    else {
+      offsetH = arith::ConstantIndexOp::create(rewriter, loc, 0);
+      Value tmpStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+      // offsetH requires computing the linear offset using the strides.
+      for (size_t idx = 0; idx < opOffsetsSize - 1; idx++) {
+        size_t revIdx = opOffsetsSize - 2 - idx;
+        Value offsetVal =
+            getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[revIdx]);
+        offsetVal = getValueOrCreateCastToIndexLike(
+            rewriter, loc, rewriter.getIndexType(), offsetVal);
+        Value mul =
+            rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, offsetVal);
+        Value dimSize =
+            getValueOrCreateConstantIntOp(rewriter, loc, mixedSizes[revIdx]);
+        dimSize = getValueOrCreateCastToIndexLike(
+            rewriter, loc, rewriter.getIndexType(), dimSize);
+        tmpStride =
+            rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, dimSize);
+        offsetH = rewriter.createOrFold<arith::AddIOp>(loc, offsetH, mul);
+      }
+    }
     offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
@@ -927,10 +907,7 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
-      if (type.isScattered())
-        return IntegerType::get(&getContext(), 64);
-      auto i32Type = IntegerType::get(&getContext(), 32);
-      return VectorType::get(8, i32Type);
+      return IntegerType::get(&getContext(), 64);
     });
     // Convert MemDescType into flattened MemRefType for SLM
     typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index d6e36fa73bf04..38d2c6483c204 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -4,45 +4,25 @@ gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
   // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+  // CHECK-SAME: %[[ARG8:.*]]: memref<?x?xf16>) kernel {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
-      %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
-    // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
-    // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
-    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
-    // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
-    // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
-    // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
-    // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+      %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
+    // Optimized away
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2] : ui64 -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+    // CHECK-NEXT: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
     %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
-
-    // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-    // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
-    // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
-    // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
-    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
-    // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
-    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-    // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    // Optimized away
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+    // CHECK-NEXT: %c1 = arith.constant 1 : index
+    %c1 = arith.constant 1 : index
+    // CHECK-NEXT: %c64 = arith.constant 64 : index
+    %size_x = arith.constant 64 : index
+    // CHECK-NEXT: %c16 = arith.constant 16 : index
+    %BLOCK_DMODEL = arith.constant 16 : index
+    // Optimized away
+    %dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
+    // CHECK-NEXT: gpu.return
     gpu.return
   }
 }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c6bbf25b4728..0764129cfd447 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,73 +1,45 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>
   gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+    // CHECK: %[[C64_i32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C8_i32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]]
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST:.*]] : memref<8x16xf32> -> index
+    // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
    %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+    // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]]
+    // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[INTPTR_1:.*]] : index to i64
    %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-    // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-    // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-    //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-    //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-    //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
-    //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32,
-    //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
-    //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+    // CHECK: %[[VAR2:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[VAR3:.*]] = xevm.blockload2d %[[VAR2]], %[[C64_i32]], %[[C8_i32]], %[[C64_i32]],
+    // CHECK-SAME: %[[C0_i32]], %[[C0_i32]] <{cache_control = #xevm.load_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+    // CHECK-SAME: tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[VAR4:.*]] = vector.bitcast %[[VAR3]] : vector<8xi32> to vector<8xf32>
     %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
       : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-    //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
     %tid_x = gpu.thread_id x
     %tid_x_i32 = arith.index_cast %tid_x : index to i32
     %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-    //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+    // CHECK: %[[VAR7:.*]] = vector.insert
     %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-    // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-    // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
     %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
 
-    //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-    //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-    //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-    //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-    //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-    //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-    //CHECK-SAME: <{cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32,
-    //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+    // CHECK: %[[VAR8:.*]] = llvm.inttoptr %[[VAR1]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[VAR9:.*]] = vector.bitcast %[[VAR7]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[VAR8]], %[[C64_i32]], %[[C8_i32]], %[[C64_i32]], %[[C0_i32]], %[[C0_i32]], %[[VAR9]]
+    // CHECK-SAME: <{cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32,
+    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}>
     xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
       : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
     gpu.return
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
new file mode 100644
index 0000000000000..d80f12c06a58a
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x8x16xf32, 1>, %[[ARG1:.*]]: memref<3x3x8x16xf32, 1>) kernel {
+  gpu.func @load_store(%src: memref<3x3x8x16xf32, 1>, %dst: memref<3x3x8x16xf32, 1>) kernel {
+    // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C72_I32:.*]] = arith.constant 72 : i32
+    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<3x3x8x16xf32> -> index
+    // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    %srcce = memref.memory_space_cast %src : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+    // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+    // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<3x3x8x16xf32> -> index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
+    %dstte = memref.memory_space_cast %dst : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[VAR2:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[VAR2]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]],
+    // CHECK-SAME: %[[C0_I32]], %[[C64_I32]] <{cache_control = #xevm.load_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+    // CHECK-SAME: tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+    %tid_x = gpu.thread_id x
+    %tid_x_i32 = arith.index_cast %tid_x : index to i32
+    %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+    // CHECK: %[[VAR7:.*]] = vector.insert
+    %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+
+    // CHECK: %[[VAR8:.*]] = llvm.inttoptr %[[VAR1]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[VAR9:.*]] = vector.bitcast %[[VAR7]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[VAR8]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]], %[[C0_I32]], %[[C32_I32]], %[[VAR9]]
+    // CHECK-SAME: <{cache_control = #xevm.store_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
+    xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir
new file mode 100644
index 0000000000000..16ecd978ad307
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>, %[[ARG1:.*]]: memref<?x?x?x?xf32>) kernel {
+  gpu.func @load_store(%src: memref<?x?x?x?xf32>, %dst: memref<?x?x?x?xf32>) kernel {
+    // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C72_I32:.*]] = arith.constant 72 : i32
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<?x?x?x?xf32> -> index
+    // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[INTPTR_0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?x?xf32> -> index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[INTPTR_0]] : index to i64
+    %dim0 = arith.constant 3 : index
+    %dim1 = arith.constant 3 : index
+    %dim2 = arith.constant 8 : index
+    %dim3 = arith.constant 16 : index
+    %stride3 = arith.constant 1 : index
+    %stride2 = arith.constant 16 : index
+    %stride1 = arith.constant 128 : index
+    %stride0 = arith.constant 384 : index
+
+    %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%dim0, %dim1, %dim2, %dim3],
+      strides:[%stride0, %stride1, %stride2, %stride3] : memref<?x?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[VAR2:.*]] = llvm.inttoptr %[[VAR1]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[VAR2]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]],
+    // CHECK-SAME: %[[C0_I32]], %[[C64_I32]] <{cache_control = #xevm.load_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+    // CHECK-SAME: tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[LOADED_F32:.*]] = vector.bitcast %[[LOADED]] : vector<8xi32> to vector<8xf32>
+    %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+    %tid_x = gpu.thread_id x
+    %tid_x_i32 = arith.index_cast %tid_x : index to i32
+    %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+    // CHECK: %[[LOADED_MODIFIED:.*]] = vector.insert
+    %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+    %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%dim0, %dim1, %dim2, %dim3],
+      strides:[%stride0, %stride1, %stride2, %stride3] : memref<?x?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+
+    // CHECK: %[[VAR8:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED_MODIFIED_BC:.*]] = vector.bitcast %[[LOADED_MODIFIED]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[VAR8]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]],
+    // CHECK-SAME: %[[C0_I32]], %[[C32_I32]], %[[LOADED_MODIFIED_BC]] <{cache_control = #xevm.store_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
+    xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir
new file mode 100644
index 0000000000000..428534c628314
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize -cse %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64,
+  // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index, %[[ARG8:.*]]: index, %[[ARG9:.*]]: index
+  gpu.func @load_store(%src: i64, %dst: i64, %dim0: index, %dim1: index, %dim2: index, %dim3: index,
+    %stride0: index, %stride1: index, %stride2: index, %stride3: index) kernel {
+    // CHECK: %[[C2:.*]] = arith.constant 2 : index
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C4_I32:.*]] = arith.constant 4 : i32
+    // CHECK: %[[VAR0:.*]] = arith.index_cast %[[ARG5]] : index to i32
+    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[ARG2]] : index to i64
+    // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64
+    // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[VAR2]] : i64
+    // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG4]] : index to i64
+    // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[VAR4]] : i64
+    // CHECK: %[[VAR6:.*]] = arith.trunci %[[VAR5]] : i64 to i32
+    // CHECK: %[[VAR7:.*]] = arith.muli %[[ARG4]], %[[C2]] : index
+    // CHECK: %[[VAR8:.*]] = arith.muli %[[ARG4]], %[[ARG3]] : index
+    // CHECK: %[[VAR9:.*]] = arith.muli %[[VAR8]], %[[C2]] : index
+    // CHECK: %[[VAR10:.*]] = arith.addi %[[VAR7]], %[[VAR9]] : index
+    // CHECK: %[[VAR11:.*]] = arith.index_cast %[[VAR10]] : index to i32
+    %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%dim0, %dim1, %dim2, %dim3],
+      strides:[%stride0, %stride1, %stride2, %stride3] : i64 -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[SRC_PTR:.*]] = llvm.inttoptr %[[ARG0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[VAR13:.*]] = arith.muli %[[VAR0]], %[[C4_I32]] : i32
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRC_PTR]], %[[VAR13]], %[[VAR6]], %[[VAR13]], %[[C0_I32]], %[[VAR11]] <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[VAR15:.*]] = vector.bitcast %[[LOADED]] : vector<8xi32> to vector<8xf32>
+    %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+    %tid_x = gpu.thread_id x
+    %tid_x_i32 = arith.index_cast %tid_x : index to i32
+    %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+    // CHECK: %[[LOADED_MODIFIED:.*]] = vector.insert
+    %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+    // CHECK: %[[VAR19:.*]] = arith.addi %[[ARG4]], %[[VAR8]] : index
+    // CHECK: %[[VAR20:.*]] = arith.index_cast %[[VAR19]] : index to i32
+    %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%dim0, %dim1, %dim2, %dim3],
+      strides:[%stride0, %stride1, %stride2, %stride3] : i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+
+    // CHECK: %[[DST_PTR:.*]] = llvm.inttoptr %[[ARG1]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED_MODIFIED_BITCAST:.*]] = vector.bitcast %[[LOADED_MODIFIED]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[DST_PTR]], %[[VAR13]], %[[VAR6]], %[[VAR13]], %[[C0_I32]], %[[VAR20]], %[[LOADED_MODIFIED_BITCAST]] <{cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
+    xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
new file mode 100644
index 0000000000000..c8ce0b3021b3f
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: ui64, %[[ARG1:.*]]: ui32) kernel {
+  gpu.func @load_store(%src: ui64, %dst: ui32) kernel {
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0
+    // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+    // CHECK: %[[ARG1_IDX:.*]] = index.castu %[[ARG1]] : ui32 to index
+    // CHECK: %[[ARG1_I32:.*]] = arith.index_castui %[[ARG1_IDX]] : index to i32
+    // CHECK: %[[ARG0_IDX:.*]] = index.castu %[[ARG0]] : ui64 to index
+    // CHECK: %[[ARG0_I64:.*]] = arith.index_castui %[[ARG0_IDX]] : index to i64
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c1 = arith.constant 1 : index
+    %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%c8, %c16], strides:[%c16, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+
+    // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[ARG0_I64]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOAD:.*]] = xevm.blockload2d %[[VAR4]], %[[C64_I32]], %[[C8_I32]], %[[C64_I32]],
+    // CHECK-SAME: %[[C0_I32]], %[[C0_I32]] <{cache_control = #xevm.load_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+    // CHECK-SAME: tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[VAR6:.*]] = vector.bitcast %[[LOAD]] : vector<8xi32> to vector<8xf32>
+    %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+    %tid_x = gpu.thread_id x
+    %tid_x_i32 = arith.index_cast %tid_x : index to i32
+    %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+    // CHECK: %[[VAR9:.*]] = vector.insert
+    %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+    // CHECK: %[[VAR10:.*]] = arith.extui %[[ARG1_I32]] : i32 to i64
+    %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%c8, %c16], strides:[%c16, %c1] : ui32 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+
+    // CHECK: %[[VAR11:.*]] = llvm.inttoptr %[[VAR10]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[STORE:.*]] = vector.bitcast %[[VAR9]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[VAR11]], %[[C64_I32]], %[[C8_I32]], %[[C64_I32]], %[[C0_I32]], %[[C0_I32]], %[[STORE]]
+    // CHECK-SAME: <{cache_control = #xevm.store_cache_control,
+    // CHECK-SAME: elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
+    xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+      : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 873478aed57e3..09f2108cc5aed 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -1,40 +1,34 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
-gpu.module @fence_check {
-  gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+gpu.module @prefetch_nd_check {
+  // CHECK-LABEL: gpu.func @prefetch_nd(
+  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
+  gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+    // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+    // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
    %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-    // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-    // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
-      #xegpu.block_tdesc_attr, #xegpu.layout>
-
-    //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-    //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-    //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-    //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-    //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
-    //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
-    //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
-    //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
-    //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
-    //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
-    //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
-    //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32,
-    //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
-    //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+      #xegpu.block_tdesc_attr>
+    // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
+    // CHECK: %[[VAR1:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+    // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
+    // CHECK: %[[VAR2:.*]] = arith.trunci %[[C8_I64]] : i64 to i32
+    // CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
+    // CHECK: %[[VAR3:.*]] = arith.trunci %[[C0_I64]] : i64 to i32
+    // CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64
+    // CHECK: %[[VAR4:.*]] = arith.trunci %[[C0_I64_1]] : i64 to i32
+    // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[C4_I32:.*]] = arith.constant 4 : i32
+    // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR1]], %[[C4_I32]] : i32
+    // CHECK: xevm.blockprefetch2d %[[VAR5]], %[[VAR6]], %[[VAR2]], %[[VAR6]], %[[VAR3]], %[[VAR4]]
+    // CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32,
+    // CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
     xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
-      : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr,
-        #xegpu.layout>
+      : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
     gpu.return
  }