diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..9b4c620bf518d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
       }
       baseAddr =
          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
     } else {
       baseAddr = adaptor.getSource();
+      if (baseAddr.getType() != i64Ty) {
+        // Pointer type may be i32. Cast to i64 if needed.
+        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      }
+    }
+    // 1D tensor descriptor is just the base address.
+    if (rank == 1) {
+      rewriter.replaceOp(op, baseAddr);
+      return success();
     }
     // Utility for creating offset values from op fold result.
     auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
-      // Pointer type may be i32. Cast to i64 if needed.
-      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -257,108 +258,175 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     auto mixedOffsets = op.getMixedOffsets();
     int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    auto tileRank = tdescTy.getRank();
+    if (opOffsetsSize != tileRank)
+      return rewriter.notifyMatchFailure(
+          op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets are provided by the op.
-    // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
-      }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
+    if (tileRank == 2) {
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+      Value payLoadAsI64 =
+          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+      Value basePtr =
+          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                    static_cast<int>(NdTdescOffset::BasePtr));
+      Value baseShapeW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+      Value baseShapeH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      // Offsets are provided by the op.
+      // convert them to i32.
+      Value offsetW =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      Value offsetH =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // Compute width in bytes.
+      Value surfaceW =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+      // Get tile width from the tensor descriptor type.
+      auto tileW = tdescTy.getDimSize(tileRank - 1);
+      // Get tile height from the tensor descriptor type.
+      auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
+      int32_t vblocks = tdescTy.getArrayLength();
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
             rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
+            offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
-      VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-      const bool vnni = op.getPacked().value_or(false);
-      auto transposeValue = op.getTranspose();
-      bool transpose =
-          transposeValue.has_value() && transposeValue.value()[0] == 1;
-      VectorType loadedTy = encodeVectorTypeTo(
-          dstVecTy, vnni ? rewriter.getI32Type()
-                         : rewriter.getIntegerType(elemBitSize));
-
-      Value resultFlatVec = xevm::BlockLoad2dOp::create(
-          rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-          surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-          transpose, vnni,
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
+      }
+    } else {
+      // 1D tensor descriptor.
+      // `tdesc` represents base address as i64.
+      // Offset is in number of elements; need to multiply by element byte size.
+      // Compute byte offset.
+      // byteOffset = offset * elementByteSize
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI64Type(), offset);
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+      // Final address = basePtr + byteOffset
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, tdesc,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-      resultFlatVec = vector::BitCastOp::create(
-          rewriter, loc,
-          encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-          resultFlatVec);
-      rewriter.replaceOp(op, resultFlatVec);
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
       }
     }
     return success();
@@ -927,7 +995,10 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      // Scattered descriptors are not supported in XeVM lowering.
       if (type.isScattered())
+        return {};
+      if (type.getRank() == 1)
         return IntegerType::get(&getContext(), 64);
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 09ef76c9d1740..109312218afae 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -29,13 +29,13 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
     // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
     // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
     // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
     // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
-    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
@@ -53,11 +53,11 @@ gpu.module @create_nd_tdesc {
     %BLOCK_DMODEL = arith.constant 16 : index
     // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref -> index
+    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
     // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
     // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
     // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
     // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-    // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
     // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
     // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
new file mode 100644
index 0000000000000..7b4ad9ec2df03
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: @load_store(
+  // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1>
+  gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel {
+    // CHECK: %[[C512:.*]] = arith.constant 512 : i64
+    // CHECK: %[[C384:.*]] = arith.constant 384 : i64
+
+    // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
+    %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
+    // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
+    %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>
+
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
+    // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
+    // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
+    // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control}>
+    // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
+    %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+        : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
+
+    // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
+    // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
+    %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr>
+    // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
+    // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
+    // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control}>
+    // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
+    xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}>
+        : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c6bbf25b4728..95774ca67c4f2 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -16,6 +16,7 @@ gpu.module @load_store_check {
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
+    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
     //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
     //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
     //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
@@ -25,7 +26,6 @@ gpu.module @load_store_check {
     //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
     //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
     //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
     //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
     //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
     //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
@@ -52,6 +52,7 @@ gpu.module @load_store_check {
     // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
     %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr>
 
+    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
     //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
     //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
     //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
@@ -61,7 +62,6 @@ gpu.module @load_store_check {
     //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
     //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
     //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
     //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
     //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
     //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 873478aed57e3..ac13ae9013593 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -16,6 +16,7 @@ gpu.module @fence_check {
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr, #xegpu.layout>
 
+    //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
     //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
     //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
     //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
@@ -25,7 +26,6 @@ gpu.module @fence_check {
     //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
     //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
     //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
-    //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
     //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
     //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
     //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
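---
For reference, the new 1D path lowers a rank-1 tensor descriptor to a bare i64 base address (see the TensorDescType conversion above) and folds the element offset into the pointer before emitting the 1D xevm.blockload/xevm.blockstore ops. A minimal sketch of the lowering that loadstore_1d.mlir checks, assuming an f32 element type (4 bytes per element) and with illustrative SSA names and elided cache-hint attribute parameters:

  // Input: rank-1 descriptor over a 512-element buffer, element offset 96.
  %tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
  %v = xegpu.load_nd %tdesc[96] : !xegpu.tensor_desc<32xf32> -> vector<2xf32>

  // After -convert-xegpu-to-xevm -canonicalize, roughly:
  %base = arith.index_castui %intptr : index to i64   // rank-1 tdesc is just the base address
  %addr = arith.addi %base, %c384_i64 : i64           // 96 elements * 4 bytes = 384
  %ptr  = llvm.inttoptr %addr : i64 to !llvm.ptr<1>
  %ld   = xevm.blockload %ptr : (!llvm.ptr<1>) -> vector<2xi32>
  %v    = vector.bitcast %ld : vector<2xi32> to vector<2xf32>

Stores follow the same shape via xevm.blockstore; xegpu.prefetch_nd on a rank-1 descriptor is rejected with a match failure, as the pattern above documents.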