From 9428381f89fb83dd872e67819c8d2ac2c74150eb Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 14 Jul 2025 18:54:41 +0000 Subject: [PATCH 01/18] Add XeGPUToXeVM conversion pass and tests. --- mlir/include/mlir/Conversion/Passes.h | 1 + mlir/include/mlir/Conversion/Passes.td | 12 + .../mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h | 27 + mlir/lib/Conversion/CMakeLists.txt | 1 + .../lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 25 + .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 932 ++++++++++++++++++ .../XeGPUToXeVM/create_nd_tdesc.mlir | 48 + mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 17 + mlir/test/Conversion/XeGPUToXeVM/fence.mlir | 15 + .../Conversion/XeGPUToXeVM/loadstore_nd.mlir | 71 ++ .../XeGPUToXeVM/loadstoreprefetch.mlir | 357 +++++++ .../Conversion/XeGPUToXeVM/prefetch_nd.mlir | 40 + .../Conversion/XeGPUToXeVM/update_offset.mlir | 25 + 13 files changed, 1571 insertions(+) create mode 100644 mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h create mode 100644 mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt create mode 100644 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp create mode 100644 mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir create mode 100644 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir create mode 100644 mlir/test/Conversion/XeGPUToXeVM/fence.mlir create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir create mode 100644 mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir create mode 100644 mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 91b2ecf8922a3..da061b269daf7 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -82,6 +82,7 @@ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h" +#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 2058aba7f9e37..323af3e97e2d4 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1555,4 +1555,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> { let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// XeGPUToXeVM +//===----------------------------------------------------------------------===// + +def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> { + let summary = "Convert XeGPU to XeVM dialect"; + let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect", + "memref::MemRefDialect", "arith::ArithDialect", + "LLVM::LLVMDialect", "index::IndexDialect", + "gpu::GPUDialect", "scf::SCFDialect"]; +} + #endif // MLIR_CONVERSION_PASSES diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h new file mode 100644 index 0000000000000..fb23d24b0161b --- /dev/null +++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h @@ -0,0 +1,27 @@ +//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect ---------_--*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_ +#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_ + +#include + +namespace mlir { +class DialectRegistry; +class LLVMTypeConverter; +class RewritePatternSet; +class Pass; + +#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS +#include "mlir/Conversion/Passes.h.inc" + +void populateXeGPUToXeVMConversionPatterns( + mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter); + +} // namespace mlir + +#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 171f7169fd41d..134fe8e14ca38 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -76,3 +76,4 @@ add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) add_subdirectory(VectorToXeGPU) add_subdirectory(XeVMToLLVM) +add_subdirectory(XeGPUToXeVM) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt new file mode 100644 index 0000000000000..ed54b0bb5ee81 --- /dev/null +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -0,0 +1,25 @@ +add_mlir_conversion_library(MLIRXeGPUToXeVM + XeGPUToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRGPUDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRXeVMDialect + MLIRVectorDialect + MLIRArithDialect + MLIRIndexDialect + MLIRXeGPUDialect + MLIRPass + MLIRTransforms +) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp new file mode 100644 index 0000000000000..380409afbc62e --- /dev/null +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -0,0 +1,932 @@ +//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/FormatVariadic.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +enum class NdDescI32Layout : uint32_t { + BasePtr = 0, + BaseShapeW = 2, + BaseShapeH = 3, + TensorOffsetW = 4, + TensorOffsetH = 5 +}; + +static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { + switch (xeGpuMemspace) { + case xegpu::MemorySpace::Global: + return static_cast(xevm::AddrSpace::GLOBAL); + case xegpu::MemorySpace::SLM: + return static_cast(xevm::AddrSpace::SHARED); + } + llvm_unreachable("Unknown XeGPU memory space."); +} + +template +std::tuple checkAllLinear(SmallVector denseAttr) { + assert(!denseAttr.empty()); + const int32_t intercept{static_cast(denseAttr[0])}; + if (denseAttr.size() < 2) + return {true, 0, intercept}; + const T slope{denseAttr[1] - denseAttr[0]}; + for (size_t i = 1; i < denseAttr.size(); ++i) + if (denseAttr[i] - denseAttr[i - 1] != slope) + return {false, 0, 0}; + return {true, static_cast(slope), intercept}; +} + +VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) { + auto elemType = currentVecType.getElementType(); + auto currentBitWidth = elemType.getIntOrFloatBitWidth(); + auto newBitWidth = toElemType.getIntOrFloatBitWidth(); + const int size = + currentVecType.getNumElements() * currentBitWidth / newBitWidth; + return VectorType::get(size, toElemType); +} + +xevm::LoadCacheControl +translateLoadXeGPUCacheHint(std::optional L1hint, + std::optional L3hint) { + auto L1hintVal = + L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED; + auto L3hintVal = + L3hint.has_value() ? 
L3hint.value() : xegpu::CachePolicy::UNCACHED; + switch (L1hintVal) { + case xegpu::CachePolicy::CACHED: + if (L3hintVal == xegpu::CachePolicy::CACHED) + return xevm::LoadCacheControl::L1C_L2UC_L3C; + else if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::LoadCacheControl::L1C_L2UC_L3UC; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::UNCACHED: + if (L3hintVal == xegpu::CachePolicy::CACHED) + return xevm::LoadCacheControl::L1UC_L2UC_L3C; + else if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::LoadCacheControl::L1UC_L2UC_L3UC; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::STREAMING: + if (L3hintVal == xegpu::CachePolicy::CACHED) + return xevm::LoadCacheControl::L1S_L2UC_L3C; + else if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::LoadCacheControl::L1S_L2UC_L3UC; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::READ_INVALIDATE: + return xevm::LoadCacheControl::INVALIDATE_READ; + default: + llvm_unreachable("Unsupported cache control."); + } +} + +xevm::StoreCacheControl +translateStoreXeGPUCacheHint(std::optional L1hint, + std::optional L3hint) { + auto L1hintVal = + L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED; + auto L3hintVal = + L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED; + switch (L1hintVal) { + case xegpu::CachePolicy::UNCACHED: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1UC_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1UC_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::STREAMING: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1S_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1S_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::WRITE_BACK: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1WB_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1WB_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::WRITE_THROUGH: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1WT_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1WT_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + default: + llvm_unreachable("Unsupported cache control."); + } +} + +class CreateNdDescToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateNdDescOp op, + xegpu::CreateNdDescOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto source = op.getSource(); + Type payloadElemTy = rewriter.getI32Type(); + Type i64Ty = rewriter.getI64Type(); + VectorType payloadTy = VectorType::get(8, payloadElemTy); + VectorType payloadI64Ty = VectorType::get(4, i64Ty); + Value payload = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0))); + + Value baseAddr; + Value baseShapeW; + Value baseShapeH; + Value offsetW; + Value offsetH; + + bool sourceIsMemref = false; + auto sourceTy = source.getType(); + int64_t rank; + if (isa(sourceTy)) { + 
sourceIsMemref = true; + baseAddr = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); + auto sourceMemrefTy = cast(sourceTy); + if (!sourceMemrefTy.hasStaticShape()) { + op.emitError() << "Expected static memref shape."; + return failure(); + } + rank = sourceMemrefTy.getRank(); + if (rank != 2) { + op.emitError() << "Expected a 2D memref."; + return failure(); + } + } else if (sourceTy == rewriter.getIntegerType(64, false)) { + rank = op.getMixedSizes().size(); + } else { + op.emitError() << "Expected source to be a 2D memref or ui64."; + return failure(); + } + auto createOffset = [&](unsigned idx) -> Value { + Value val; + OpFoldResult ofr = op.getMixedOffsets()[idx]; + if (auto v = llvm::dyn_cast_if_present(ofr)) { + val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v); + val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val); + } else { + int32_t off = llvm::cast(cast(ofr)).getInt(); + val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off); + } + return val; + }; + auto offsets = op.getMixedOffsets(); + if (offsets.size() == 2) { + offsetW = createOffset(rank - 1); + offsetH = createOffset(rank - 2); + } else { + offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); + offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); + } + auto createShape = [&](unsigned idx) -> Value { + Value val; + OpFoldResult ofr = op.getMixedSizes()[idx]; + if (auto v = llvm::dyn_cast_if_present(ofr)) { + val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v); + val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val); + } else { + int32_t off = llvm::cast(cast(ofr)).getInt(); + val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off); + } + return val; + }; + if (sourceIsMemref) { + auto sourceMemrefTy = cast(sourceTy); + baseShapeW = arith::ConstantIntOp::create( + rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1)); + baseShapeH = arith::ConstantIntOp::create( + rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2)); + baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); + } else { + baseShapeW = createShape(rank - 1); + baseShapeH = createShape(rank - 2); + baseAddr = adaptor.getSource(); + } + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); + payLoadAsI64 = + vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64, + static_cast(NdDescI32Layout::BasePtr)); + payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64); + payload = + vector::InsertOp::create(rewriter, loc, baseShapeW, payload, + static_cast(NdDescI32Layout::BaseShapeW)); + payload = + vector::InsertOp::create(rewriter, loc, baseShapeH, payload, + static_cast(NdDescI32Layout::BaseShapeH)); + payload = vector::InsertOp::create( + rewriter, loc, offsetW, payload, + static_cast(NdDescI32Layout::TensorOffsetW)); + payload = vector::InsertOp::create( + rewriter, loc, offsetH, payload, + static_cast(NdDescI32Layout::TensorOffsetH)); + rewriter.replaceOp(op, payload); + return success(); + } +}; + +class UpdateNdOffsetToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::UpdateNdOffsetOp op, + xegpu::UpdateNdOffsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto offsets = op.getOffsets(); + auto tdesc = adaptor.getTensorDesc(); + for (size_t offsetDim = 0; offsetDim < offsets.size(); 
offsetDim++) { + auto offset = offsets[offsetDim]; + if (auto cst = + dyn_cast_if_present(offset.getDefiningOp())) + if (auto attr = dyn_cast_if_present(cst.getValue()); + attr && !attr.getInt()) + continue; + const int offsetPos = + static_cast(offsetDim ? NdDescI32Layout::TensorOffsetW + : NdDescI32Layout::TensorOffsetH); + auto oldOffset = + vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos); + offset = arith::IndexCastUIOp::create(rewriter, loc, + rewriter.getI32Type(), offset); + auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); + tdesc = + vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos); + } + rewriter.replaceOp(op, tdesc); + return success(); + } +}; + +template < + typename OpType, + typename = std::enable_if_t::value>> +class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + + auto tdesc = adaptor.getTensorDesc(); + auto tdescTy = op.getTensorDescType(); + if (tdescTy.getRank() != 2) { + return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + } + + VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); + Value basePtr = + vector::ExtractOp::create(rewriter, loc, payLoadAsI64, + static_cast(NdDescI32Layout::BasePtr)); + Value baseShapeW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast(NdDescI32Layout::BaseShapeW)); + Value baseShapeH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast(NdDescI32Layout::BaseShapeH)); + // Offsets can come from three sources: + // 1. Constant offsets, which are provided by the op. + // 2. Offsets as operands, which are provided by the op. + // 3. Offsets extracted from the tensor descriptor. + Value offsetW; + Value offsetH; + auto cOffsets = op.getConstOffsets(); + auto offsets = op.getOffsets(); + if (cOffsets) { + offsetW = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]); + offsetH = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]); + } else if (offsets.size() != 0) { + // offsets are provided as operands + if (offsets[0].getType() != rewriter.getI32Type()) { + if (offsets[0].getType() != rewriter.getIndexType()) { + return rewriter.notifyMatchFailure( + op, "Expected offsets to be of type i32 or index."); + } + offsetW = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), offsets[0]); + } else { + offsetW = offsets[0]; + } + if (offsets[1].getType() != rewriter.getI32Type()) { + if (offsets[1].getType() != rewriter.getIndexType()) { + return rewriter.notifyMatchFailure( + op, "Expected offsets to be of type i32 or index."); + } + offsetH = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), offsets[1]); + } else { + offsetH = offsets[1]; + } + } else { + // If offsets are not available, we need to extract them from the tensor + // descriptor. 
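+      // The W/H offsets live in i32 lanes 4 and 5 of the payload (see
+      // NdDescI32Layout::TensorOffsetW/H), so read them back from there.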
+ offsetW = vector::ExtractOp::create( + rewriter, loc, tdesc, + static_cast(NdDescI32Layout::TensorOffsetW)); + offsetH = vector::ExtractOp::create( + rewriter, loc, tdesc, + static_cast(NdDescI32Layout::TensorOffsetH)); + } + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + auto elemType = tdescTy.getElementType(); + auto elemBitSize = elemType.getIntOrFloatBitWidth(); + // auto elemBitSizeAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), + // elemBitSize); + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); + Value surfaceW = + arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + + auto tileW = tdescTy.getDimSize(1); + auto tileH = tdescTy.getDimSize(0); + int32_t vblocks = tdescTy.getArrayLength(); + if constexpr (std::is_same_v) { + VectorType srcVecTy = cast(op.getValue().getType()); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + VectorType srcFlatVecTy = + VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType()); + Value srcFlatVec = op.getValue(); + srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy, + rewriter.getIntegerType(elemBitSize)); + srcFlatVec = + vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec); + xevm::BlockStore2dOp::create( + rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, + offsetH, elemBitSize, tileW, tileH, srcFlatVec, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); + rewriter.eraseOp(op); + } else { + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + if constexpr (std::is_same_v) { + xevm::BlockPrefetch2dOp::create( + rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, + offsetH, elemBitSize, tileW, tileH, vblocks, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + rewriter.eraseOp(op); + } else { + VectorType dstVecTy = cast(op.getValue().getType()); + const bool vnni = op.getPacked().value_or(false); + auto transposeValue = op.getTranspose(); + bool transpose = + transposeValue.has_value() && transposeValue.value()[0] == 1; + VectorType loadedTy = encodeVectorTypeTo( + dstVecTy, vnni ? rewriter.getI32Type() + : rewriter.getIntegerType(elemBitSize)); + + Value resultFlatVec = xevm::BlockLoad2dOp::create( + rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH, + surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks, + transpose, vnni, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + resultFlatVec = vector::BitCastOp::create( + rewriter, loc, + encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), + resultFlatVec); + rewriter.replaceOp(op, resultFlatVec); + } + } + return success(); + } +}; + +template < + typename OpType, + typename = std::enable_if_t::value>> +int64_t getElemByteSize(OpType op) { + // Get the element byte size from the tensor descriptor. 
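+  // E.g. f32 -> 4 bytes, f16 -> 2 bytes; the integer division means a
+  // sub-byte element type would truncate to 0 here.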
+ auto elemBitWidth = + op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth(); + return elemBitWidth / 8; +} + +// Add a builder that creates +// offset * elemByteSize + baseAddr +auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc, + Value baseAddr, Value offset, + int64_t elemByteSize) -> Value { + Value byteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), elemByteSize); + Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); + Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); + return newAddr; +}; + +class CreateDescToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto offsets = adaptor.getOffsets(); + // Source type can be a 1D memref or ui64 + // Using "op" instead of "adaptor" since we want to access memref type + // instead of LLVM struct type. + auto memrefTy = dyn_cast(op.getSource().getType()); + Value subGroupAddr; + if (memrefTy) { + subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, loc, op.getSource()); + subGroupAddr = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI64Type(), subGroupAddr); + } else { + subGroupAddr = adaptor.getSource(); + } + auto laneAddr = + addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op)); + rewriter.replaceOp(op, laneAddr); + return success(); + } +}; + +class UpdateOffsetToXeVMPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::UpdateOffsetOp op, + xegpu::UpdateOffsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value newOffsetForLane = + addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(), + getElemByteSize(op)); + rewriter.replaceOp(op, newOffsetForLane); + return success(); + } +}; + +template ::value>> +class LoadStoreToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdescTy = op.getTensorDescType(); + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + Value basePtrI64; + if constexpr (std::is_same_v) { + basePtrI64 = adaptor.getSource(); + } else { + basePtrI64 = adaptor.getDest(); + } + Value offsets = adaptor.getOffsets(); + Value mask = adaptor.getMask(); + if (offsets) { + VectorType offsetsVecTy = dyn_cast(offsets.getType()); + if (offsetsVecTy) { + // Offset needs be scalar. 
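+        // Only a per-lane scalar offset is supported; a vector of offsets
+        // makes the pattern fail to match instead of being scalarized here.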
+ return rewriter.notifyMatchFailure(op, + "Expected offsets to be a scalar."); + } else { + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op)); + } + } + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + VectorType srcOrDstVecTy = op.getValueType(); + VectorType srcOrDstFlatVecTy = VectorType::get( + srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); + Value maskForLane; + VectorType maskVecTy = dyn_cast(mask.getType()); + if (maskVecTy) { + return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar."); + } else + maskForLane = mask; + if constexpr (std::is_same_v) { + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy}, + maskForLane, true, true); + rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + Value loaded = + LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM); + loaded.getDefiningOp()->setAttr("cache_control", + xevm::LoadCacheControlAttr::get( + ctxt, translateLoadXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); + if (srcOrDstVecTy != srcOrDstFlatVecTy) { + loaded = + vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded); + } + scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); + rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); + // If mask is false, we yield a vector of zeros. + auto eTy = srcOrDstVecTy.getElementType(); + loaded = arith::ConstantOp::create( + rewriter, loc, + eTy.isFloat() + ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0)) + : DenseElementsAttr::get(srcOrDstVecTy, + IntegerAttr::get(eTy, 0))); + scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); + rewriter.replaceOp(op, ifOp.getResult(0)); + } else { + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false); + auto body = ifOp.getBody(); + rewriter.setInsertionPointToStart(body); + VectorType valTy = op.getValue().getType(); + Value srcFlatVec = op.getValue(); + if (valTy != srcOrDstFlatVecTy) { + srcFlatVec = vector::ShapeCastOp::create(rewriter, loc, + srcOrDstFlatVecTy, srcFlatVec); + } + auto storeOp = LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM); + storeOp.getOperation()->setAttr( + "cache_control", + xevm::StoreCacheControlAttr::get(ctxt, + translateStoreXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); + rewriter.eraseOp(op); + } + return success(); + } +}; + +class PrefetchToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdescTy = op.getTensorDescType(); + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + Value basePtrI64 = adaptor.getSource(); + Value offsets = adaptor.getOffsets(); + if (offsets) { + VectorType offsetsVecTy = dyn_cast(offsets.getType()); + if (offsetsVecTy) { + // Offset needs be scalar. 
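+        // Same restriction as the load/store lowering above: bail out on a
+        // vector of offsets rather than scalarizing it.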
+ return rewriter.notifyMatchFailure(op, + "Expected offsets to be a scalar."); + } else { + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op)); + } + } + Value ptrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + xevm::PrefetchOp::create( + rewriter, loc, ptrLLVM, + xevm::LoadCacheControlAttr::get( + ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()))); + rewriter.eraseOp(op); + return success(); + } +}; +class FenceToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + xevm::MemScope memScope{xevm::MemScope::WORKGROUP}; + switch (op.getFenceScope()) { + case xegpu::FenceScope::Workgroup: + memScope = xevm::MemScope::WORKGROUP; + break; + case xegpu::FenceScope::GPU: + memScope = xevm::MemScope::DEVICE; + break; + llvm_unreachable("Unknown XeGPU fence scope."); + } + xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL}; + switch (op.getMemoryKind()) { + case xegpu::MemorySpace::Global: + addrSpace = xevm::AddrSpace::GLOBAL; + break; + case xegpu::MemorySpace::SLM: + addrSpace = xevm::AddrSpace::SHARED; + break; + llvm_unreachable("Unknown XeGPU fence scope."); + } + xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace); + rewriter.eraseOp(op); + return success(); + } +}; + +class DpasToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto aTy = cast(op.getLhs().getType()); + auto bTy = cast(op.getRhs().getType()); + auto resultType = cast(op.getResultType()); + + auto encodePrecision = [&](Type type) -> xevm::ElemType { + if (type == rewriter.getBF16Type()) + return xevm::ElemType::BF16; + else if (type == rewriter.getF16Type()) + return xevm::ElemType::F16; + else if (type == rewriter.getTF32Type()) + return xevm::ElemType::TF32; + else if (type.isInteger(8)) { + if (type.isUnsignedInteger()) + return xevm::ElemType::U8; + return xevm::ElemType::S8; + } else if (type == rewriter.getF32Type()) + return xevm::ElemType::F32; + else if (type.isInteger(32)) + return xevm::ElemType::S32; + llvm_unreachable("add more support for ElemType"); + }; + xevm::ElemType precATy = encodePrecision(aTy.getElementType()); + xevm::ElemType precBTy = encodePrecision(bTy.getElementType()); + // auto precA = xevm::ElemTypeAttr::get(ctxt, precATy); + // auto precB = xevm::ElemTypeAttr::get(ctxt, precBTy); + Value c = op.getAcc(); + if (!c) { + auto elementTy = resultType.getElementType(); + Attribute initValueAttr; + if (isa(elementTy)) + initValueAttr = FloatAttr::get(elementTy, 0.0); + else + initValueAttr = IntegerAttr::get(elementTy, 0); + c = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr)); + } + + Value aVec = op.getLhs(); + Value bVec = op.getRhs(); + auto cvecty = cast(c.getType()); + xevm::ElemType precCTy = encodePrecision(cvecty.getElementType()); + xevm::ElemType precDTy = encodePrecision(resultType.getElementType()); + // auto precC = xevm::ElemTypeAttr::get(ctxt, precCTy); + // auto precD = xevm::ElemTypeAttr::get(ctxt, precDTy); + VectorType cNty = + VectorType::get(cvecty.getNumElements(), cvecty.getElementType()); + if 
(cvecty != cNty) + c = vector::ShapeCastOp::create(rewriter, loc, cNty, c); + // below are uArch dependent values, should move away from hardcoding + constexpr int32_t systolicDepth{8}; + constexpr int32_t executionSize{16}; + Value dpasRes = xevm::MMAOp::create( + rewriter, loc, cNty, aVec, bVec, c, + xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize, + systolicDepth * + getNumOperandsPerDword(precATy)), + xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy)); + if (cvecty != cNty) + dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes); + rewriter.replaceOp(op, dpasRes); + return success(); + } + +private: + static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { + switch (pTy) { + case xevm::ElemType::TF32: + return 1; + case xevm::ElemType::BF16: + case xevm::ElemType::F16: + return 2; + case xevm::ElemType::U8: + case xevm::ElemType::S8: + return 4; + default: + llvm_unreachable("unsupported xevm::ElemType"); + } + } +}; + +static std::optional +matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) { + switch (arithKind) { + case arith::AtomicRMWKind::addf: + return LLVM::AtomicBinOp::fadd; + case arith::AtomicRMWKind::addi: + return LLVM::AtomicBinOp::add; + case arith::AtomicRMWKind::assign: + return LLVM::AtomicBinOp::xchg; + case arith::AtomicRMWKind::maximumf: + return LLVM::AtomicBinOp::fmax; + case arith::AtomicRMWKind::maxs: + return LLVM::AtomicBinOp::max; + case arith::AtomicRMWKind::maxu: + return LLVM::AtomicBinOp::umax; + case arith::AtomicRMWKind::minimumf: + return LLVM::AtomicBinOp::fmin; + case arith::AtomicRMWKind::mins: + return LLVM::AtomicBinOp::min; + case arith::AtomicRMWKind::minu: + return LLVM::AtomicBinOp::umin; + case arith::AtomicRMWKind::ori: + return LLVM::AtomicBinOp::_or; + case arith::AtomicRMWKind::andi: + return LLVM::AtomicBinOp::_and; + default: + return std::nullopt; + } + llvm_unreachable("Invalid AtomicRMWKind"); +} + +class AtomicRMWToXeVMPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdesc = op.getTensorDesc().getType(); + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace())); + Value basePtrI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc()); + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + VectorType srcOrDstVecTy = cast(op.getValue().getType()); + VectorType srcOrDstFlatVecTy = VectorType::get( + srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); + Value srcFlatVec = vector::ShapeCastOp::create( + rewriter, loc, srcOrDstFlatVecTy, op.getValue()); + auto atomicKind = matchSimpleAtomicOp(op.getKind()); + assert(atomicKind.has_value()); + Value resVec = srcFlatVec; + for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) { + auto val = vector::ExtractOp::create(rewriter, loc, resVec, i); + Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(i)); + Value currPtr = + LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM, + srcOrDstVecTy.getElementType(), basePtrLLVM, idx); + Value newVal = + LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr, + val, LLVM::AtomicOrdering::seq_cst); + resVec = 
vector::InsertOp::create(rewriter, loc, newVal, resVec, i); + } + rewriter.replaceOp(op, resVec); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct ConvertXeGPUToXeVMPass + : public impl::ConvertXeGPUToXeVMPassBase { + using Base::Base; + + void runOnOperation() override { + LLVMTypeConverter typeConverter(&getContext()); + typeConverter.addConversion([&](VectorType type) -> Type { + unsigned rank = type.getRank(); + auto elemType = type.getElementType(); + // If the element type is index, convert it to i64. + if (llvm::isa(elemType)) + elemType = IntegerType::get(&getContext(), 64); + // If the vector is a scalar or has a single element, return the element + if (rank < 1 || type.getNumElements() == 1) + return elemType; + // Otherwise, convert the vector to a flat vector type. + unsigned sum = 1; + for (unsigned i = 0; i < rank; i++) { + sum *= type.getShape()[i]; + } + return VectorType::get(sum, elemType); + }); + typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { + if (type.isScattered()) { + return IntegerType::get(&getContext(), 64); + } + auto i32Type = IntegerType::get(&getContext(), 32); + return VectorType::get(8, i32Type); + }); + + auto ui64MaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType() == builder.getIntegerType(64, false)) { + Value cast = + index::CastUOp::create(builder, loc, builder.getIndexType(), input) + .getResult(); + return arith::IndexCastOp::create(builder, loc, type, cast).getResult(); + } + return {}; + }; + + auto vector1DMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (auto vecTy = dyn_cast(input.getType())) { + if (vecTy.getNumElements() == 1) { + // If the vector has a single element, return the element type. 
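+          // The type converter flattens single-element vectors (e.g.
+          // vector<1xindex>) to their element type; bridge that gap by
+          // extracting element 0 and index-casting when needed.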
+ Value cast = + vector::ExtractOp::create(builder, loc, input, 0).getResult(); + if (vecTy.getElementType() == builder.getIndexType()) + cast = arith::IndexCastOp::create(builder, loc, type, cast) + .getResult(); + return cast; + } + } + return {}; + }; + typeConverter.addSourceMaterialization(ui64MaterializationCast); + typeConverter.addSourceMaterialization(vector1DMaterializationCast); + typeConverter.addTargetMaterialization(ui64MaterializationCast); + typeConverter.addTargetMaterialization(vector1DMaterializationCast); + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(&getContext()); + populateXeGPUToXeVMConversionPatterns(patterns, typeConverter); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, + patterns, target); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// +void mlir::populateXeGPUToXeVMConversionPatterns( + RewritePatternSet &patterns, LLVMTypeConverter &typeConverter) { + patterns.add, + LoadStorePrefetchNdToXeVMPattern, + LoadStorePrefetchNdToXeVMPattern>( + typeConverter, patterns.getContext()); + patterns.add, + LoadStoreToXeVMPattern>( + typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir new file mode 100644 index 0000000000000..4fba920f023c4 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @create_nd_tdesc { + // CHECK-LABEL: gpu.func @create_nd_tdesc + // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64 + // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index + gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, + %stride1: index, %stride2: index) kernel { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 + // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32 + // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64 + // CHECK: %[[VAR3:.*]] = arith.trunci %[[VAR2]] : i64 to i32 + // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG2]] : index to i64 + // CHECK: %[[VAR5:.*]] = arith.trunci %[[VAR4]] : i64 to i32 + // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR9:.*]] = vector.insert %[[VAR3]], %[[VAR8]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR10:.*]] = vector.insert %[[VAR5]], %[[VAR9]] [3] : i32 into vector<8xi32> + // CHECK: %[[VAR11:.*]] = vector.insert %[[C0_I32]], %[[VAR10]] [4] : i32 into vector<8xi32> + // CHECK: %[[VAR12:.*]] = vector.insert %[[C0_I32_0]], %[[VAR11]] [5] : i32 into vector<8xi32> + %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], 
strides:[%stride1, %stride2] + : ui64 -> !xegpu.tensor_desc<8x16xf32> + + // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32> + %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> + + // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index + // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32 + // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32 + // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 + // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32 + // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64> + // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64> + // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[VAR17:.*]] = vector.insert %[[C16_I32]], %[[VAR16]] [2] : i32 into vector<8xi32> + // CHECK: %[[VAR18:.*]] = vector.insert %[[C8_I32]], %[[VAR17]] [3] : i32 into vector<8xi32> + // CHECK: %[[VAR19:.*]] = vector.insert %[[C0_I32_2]], %[[VAR18]] [4] : i32 into vector<8xi32> + // CHECK: %[[VAR20:.*]] = vector.insert %[[C0_I32_3]], %[[VAR19]] [5] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir new file mode 100644 index 0000000000000..15940fc4aca26 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +#sg_map_a_f16 = #xegpu.layout +#sg_map_b_f16 = #xegpu.layout +#sg_map_c_f32 = #xegpu.layout + +gpu.module @load_store_check { + //CHECK: func.func @dpas(%[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>) -> vector<8xf32> + func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> { + // Loads are checked in a separate test. 
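+    // With f16 operands the lowering picks M = 8 (rows of C), N = 16
+    // (execution size), and K = 8 (systolic depth) * 2 operands per dword = 16.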
+ // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = } + // CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> + %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32} + : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32> + return %d : vector<8xf32> + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/fence.mlir b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir new file mode 100644 index 0000000000000..cedfcace398a6 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @fence_check { + gpu.func @fence(%dst: memref<8x16xf32, 1>) kernel { + %tid_x = gpu.thread_id x + %tid_x_i32 = arith.index_cast %tid_x : index to i32 + %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32 + + // CHECK: xevm.memfence <{addrspace = #xevm.addr_space, scope = #xevm.mem_scope}> + xegpu.fence memory_kind = global, fence_scope = workgroup + %c0 = arith.constant 0 : index + memref.store %tid_x_f32, %dst[%c0, %c0] : memref<8x16xf32, 1> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir new file mode 100644 index 0000000000000..c692da632d458 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir @@ -0,0 +1,71 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @load_store_check { + gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> + %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> + + // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> + // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + + + //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> + //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> + //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> + //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> + //CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32 + //CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32 + //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1> + //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32 + //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32 + //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]], + //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]] + 
//CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32, + //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, + //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32> + %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32> + //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32> + + %tid_x = gpu.thread_id x + %tid_x_i32 = arith.index_cast %tid_x : index to i32 + %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32 + //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32> + %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32> + + // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> + // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32> + // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32> + // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32> + // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32> + %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + + //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64> + //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64> + //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32> + //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32> + //CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32 + //CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32 + //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1> + //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32 + //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32 + //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32> + //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]], + //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]] + //CHECK-SAME: <{cache_control = #xevm.store_cache_control, elem_size_in_bits = 32 : i32, + //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir new file mode 100644 index 0000000000000..f6d023307313a --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -0,0 +1,357 @@ +// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s + +gpu.module @test { +// CHECK-LABEL: @load_gather_ui64_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: ui64 +gpu.func 
@load_gather_ui64_src_constant_offset(%src: ui64) { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 + %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) { + // CHECK: %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : !llvm.ptr<1> -> vector<2xf32> + // CHECK: scf.yield %[[VAR9]] : vector<2xf32> + // CHECK: } else { + // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> + // CHECK: scf.yield %[[CST_1]] : vector<2xf32> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xi1> -> vector<2xf32> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @load_gather_memref_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> +gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) { + // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : !llvm.ptr<1> -> vector<1xf32> + // CHECK: %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32> + // CHECK: scf.yield %[[VAR9]] : f32 + // CHECK: } else { + // CHECK: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + // CHECK: %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32> + // CHECK: scf.yield %[[VAR8]] : f32 + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> 
vector<1xf32> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @load_gather_memref_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex> +gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex> + -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) { + // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16> + // CHECK: scf.yield %[[VAR8]] : vector<8xf16> + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16> + // CHECK: scf.yield %[[CST_0]] : vector<8xf16> + %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr>, vector<1xi1> -> vector<8xf16> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @load_gather_memref_src_load_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex> +gpu.func @load_gather_memref_src_load_offset(%src: memref<256xf16>, %offset1: vector<1xindex>, %offset2: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index + // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C2_I64]] : i64 + // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64 + %2 = xegpu.create_tdesc %src, %offset1 : memref<256xf16>, vector<1xindex> + -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr> + // CHECK: %[[C2_I64_0:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C2_I64_0]] : i64 + // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64 + // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1> + // CHECK: %[[VAR11:.*]] = scf.if %[[VAR4]] -> (vector<8xf16>) { + // CHECK: %[[VAR12:.*]] = llvm.load %[[VAR10]] {cache_control = #xevm.load_cache_control} + // CHECK-SAME: : 
!llvm.ptr<1> -> vector<8xf16> + // CHECK: scf.yield %[[VAR12]] : vector<8xf16> + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16> + // CHECK: scf.yield %[[CST_0]] : vector<8xf16> + %3 = xegpu.load %2[%offset2], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr>, vector<1xindex>, vector<1xi1> -> vector<8xf16> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_ui64_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: ui64 +gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32> + %2 = arith.constant dense<2.9>: vector<2xf32> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64 + %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR4]] { + // CHECK: llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : vector<2xf32>, !llvm.ptr<1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_memref_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> +gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16> + %2 = arith.constant dense<2.9>: vector<2xf16> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR2]] { + // CHECK: llvm.store %[[CST_1]], %[[VAR6]] {cache_control = 
#xevm.store_cache_control} + // CHECK-SAME: : vector<2xf16>, !llvm.ptr<1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_memref_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> +gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> + %2 = arith.constant dense<2.9>: vector<1xf32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 + %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR2]] { + // CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1> + xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @store_scatter_memref_src_store_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex> +gpu.func @store_scatter_memref_src_store_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> + %1 = arith.constant dense<1>: vector<1xi1> + // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> + %2 = arith.constant dense<2.9>: vector<1xf32> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64 + %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[C4_I64_1:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C4_I64_1]] : i64 + // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64 + 
// CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1> + // CHECK: scf.if %[[VAR4]] { + // CHECK: llvm.store %[[CST_0]], %[[VAR10]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1> + xegpu.store %2, %3[%offset2], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_ui64_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: ui64 +gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64 + %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_memref_src_constant_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> +gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + %0 = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 + %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_memref_src_value_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> +gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to 
i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 + %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + gpu.return +} +} +// ----- + +gpu.module @test { +// CHECK-LABEL: @prefetch_memref_src_prefetch_offset +// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex> +gpu.func @prefetch_memref_src_prefetch_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) { + // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR4:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64 + %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> + // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR7:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64 + // CHECK: %[[VAR8:.*]] = arith.addi %[[VAR6]], %[[VAR7]] : i64 + // CHECK: %[[VAR9:.*]] = llvm.inttoptr %[[VAR8]] : i64 to !llvm.ptr<1> + // CHECK: xevm.prefetch %[[VAR9]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + xegpu.prefetch %1[%offset2] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xindex> + gpu.return +} +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir new file mode 100644 index 0000000000000..8513b4f9857fb --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s + +gpu.module @fence_check { + gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel { + %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32> + %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32> + + // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64 + // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64> + // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64> + // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32> + // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC_4:.*]] = 
vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32> + // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32> + %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, + #xegpu.block_tdesc_attr, #xegpu.layout> + + //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64> + //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> + //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> + //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> + //CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32 + //CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32 + //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1> + //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32 + //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32 + //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]], + //CHECK-SAME: %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]] + //CHECK-SAME: <{cache_control = #xevm.load_cache_control, elem_size_in_bits = 32 : i32, + //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}> + //CHECK-SAME: : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr, + #xegpu.layout> + + gpu.return + } +} + diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir new file mode 100644 index 0000000000000..e9d7fd4cf40a6 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @update_offset { + // CHECK-LABEL: gpu.func @update_offset + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> + gpu.func @update_offset(%src: memref<128xf32>) kernel { + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + %offset = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 + // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 + %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64 + // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64 + // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64 + %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + , vector<1xindex> + gpu.return + } +} From 61bee9f1feb386f7e8402424c8b2afba75b16062 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 20 Aug 2025 15:25:41 +0000 Subject: [PATCH 02/18] Apply clang format. 
--- .../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 380409afbc62e..32983152ef5bd 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -558,10 +558,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); Value loaded = LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM); - loaded.getDefiningOp()->setAttr("cache_control", - xevm::LoadCacheControlAttr::get( - ctxt, translateLoadXeGPUCacheHint( - op.getL1Hint(), op.getL3Hint()))); + loaded.getDefiningOp()->setAttr( + "cache_control", xevm::LoadCacheControlAttr::get( + ctxt, translateLoadXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); if (srcOrDstVecTy != srcOrDstFlatVecTy) { loaded = vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded); @@ -588,12 +588,12 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { srcFlatVec = vector::ShapeCastOp::create(rewriter, loc, srcOrDstFlatVecTy, srcFlatVec); } - auto storeOp = LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM); + auto storeOp = + LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM); storeOp.getOperation()->setAttr( - "cache_control", - xevm::StoreCacheControlAttr::get(ctxt, - translateStoreXeGPUCacheHint( - op.getL1Hint(), op.getL3Hint()))); + "cache_control", xevm::StoreCacheControlAttr::get( + ctxt, translateStoreXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); rewriter.eraseOp(op); } return success(); From 4aa4cb29accaa3242cbcdb74155ed0a4eb6ec536 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 20 Aug 2025 15:29:01 +0000 Subject: [PATCH 03/18] Remove commented out code. 
--- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 32983152ef5bd..89f40c22e7a68 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -380,8 +380,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); auto elemType = tdescTy.getElementType(); auto elemBitSize = elemType.getIntOrFloatBitWidth(); - // auto elemBitSizeAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), - // elemBitSize); Value elemByteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); Value surfaceW = @@ -695,8 +693,6 @@ class DpasToXeVMPattern : public OpConversionPattern { }; xevm::ElemType precATy = encodePrecision(aTy.getElementType()); xevm::ElemType precBTy = encodePrecision(bTy.getElementType()); - // auto precA = xevm::ElemTypeAttr::get(ctxt, precATy); - // auto precB = xevm::ElemTypeAttr::get(ctxt, precBTy); Value c = op.getAcc(); if (!c) { auto elementTy = resultType.getElementType(); @@ -714,8 +710,6 @@ class DpasToXeVMPattern : public OpConversionPattern { auto cvecty = cast(c.getType()); xevm::ElemType precCTy = encodePrecision(cvecty.getElementType()); xevm::ElemType precDTy = encodePrecision(resultType.getElementType()); - // auto precC = xevm::ElemTypeAttr::get(ctxt, precCTy); - // auto precD = xevm::ElemTypeAttr::get(ctxt, precDTy); VectorType cNty = VectorType::get(cvecty.getNumElements(), cvecty.getElementType()); if (cvecty != cNty) From 687e831902974da829d692d415df14b0e50a25b2 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 20 Aug 2025 17:59:49 +0000 Subject: [PATCH 04/18] Remove dead code. --- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 89f40c22e7a68..776380974c549 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -56,19 +56,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { llvm_unreachable("Unknown XeGPU memory space."); } -template -std::tuple checkAllLinear(SmallVector denseAttr) { - assert(!denseAttr.empty()); - const int32_t intercept{static_cast(denseAttr[0])}; - if (denseAttr.size() < 2) - return {true, 0, intercept}; - const T slope{denseAttr[1] - denseAttr[0]}; - for (size_t i = 1; i < denseAttr.size(); ++i) - if (denseAttr[i] - denseAttr[i - 1] != slope) - return {false, 0, 0}; - return {true, static_cast(slope), intercept}; -} - VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) { auto elemType = currentVecType.getElementType(); auto currentBitWidth = elemType.getIntOrFloatBitWidth(); From e240e47a1c3731ba2beb8af5e5e8f08f858a84c1 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Thu, 21 Aug 2025 22:41:37 +0000 Subject: [PATCH 05/18] Temp save. 
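
Work-in-progress snapshot of the scatter-path lowering: memref, ui64, and
ui32 sources as well as single-element vectors are flattened to scalar
integers through type-converter materialization casts, and the signed index
casts are replaced with unsigned ones (arith.index_castui). As a rough
sketch, assuming a memref source %src and a vector<1xindex> offset %offsets
(value names here are illustrative; the updated tests below are
authoritative), the materialized IR looks like:

  %intptr = memref.extract_aligned_pointer_as_index %src : memref<128xf32> -> index
  %base   = arith.index_castui %intptr : index to i64
  %off    = vector.extract %offsets[0] : index from vector<1xindex>
  %off64  = arith.index_castui %off : index to i64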
--- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 128 +++++++++++++----- .../XeGPUToXeVM/create_nd_tdesc.mlir | 4 +- .../XeGPUToXeVM/loadstoreprefetch.mlir | 4 +- .../XeGPUToXeVM/materializecast.mlir | 49 +++++++ .../Conversion/XeGPUToXeVM/update_offset.mlir | 6 +- 5 files changed, 150 insertions(+), 41 deletions(-) create mode 100644 mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 776380974c549..4ff5321c1d9d2 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -426,18 +427,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { } }; -template < - typename OpType, - typename = std::enable_if_t::value>> -int64_t getElemByteSize(OpType op) { - // Get the element byte size from the tensor descriptor. - auto elemBitWidth = - op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth(); - return elemBitWidth / 8; -} - // Add a builder that creates // offset * elemByteSize + baseAddr auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc, @@ -456,23 +445,23 @@ class CreateDescToXeVMPattern LogicalResult matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto eTy = op.getTensorDescType().getElementType(); + if (eTy.getIntOrFloatBitWidth() % 8 != 0) { + return rewriter.notifyMatchFailure(op, + "Expected element type bit width to be multiple of 8."); + } auto loc = op.getLoc(); + // offsets are provided as scalar i64 by type converter. auto offsets = adaptor.getOffsets(); - // Source type can be a 1D memref or ui64 - // Using "op" instead of "adaptor" since we want to access memref type - // instead of LLVM struct type. - auto memrefTy = dyn_cast(op.getSource().getType()); - Value subGroupAddr; - if (memrefTy) { - subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create( - rewriter, loc, op.getSource()); - subGroupAddr = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI64Type(), subGroupAddr); - } else { - subGroupAddr = adaptor.getSource(); - } + // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32). + // But type converter will convert them to integer types. + Value addr = adaptor.getSource(); + // ui32 or i32 are passed as i32 so they need to be casted to i64. 
+ if (addr.getType() != rewriter.getI64Type()) + addr = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI64Type(), addr); auto laneAddr = - addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op)); + addOffset(rewriter, loc, addr, offsets, getElemByteSize(op)); rewriter.replaceOp(op, laneAddr); return success(); } @@ -485,11 +474,18 @@ class UpdateOffsetToXeVMPattern matchAndRewrite(xegpu::UpdateOffsetOp op, xegpu::UpdateOffsetOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto eTy = op.getTensorDescType().getElementType(); + if (eTy.getIntOrFloatBitWidth() % 8 != 0) { + return rewriter.notifyMatchFailure(op, + "Expected element type bit width to be multiple of 8."); + } auto loc = op.getLoc(); - Value newOffsetForLane = + // scatter descriptor is provided as scalar i64 by type converter. + // offsets are provided as scalar i64 by type converter. + Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(), getElemByteSize(op)); - rewriter.replaceOp(op, newOffsetForLane); + rewriter.replaceOp(op, newOffset); return success(); } }; @@ -505,19 +501,38 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdescTy = op.getTensorDescType(); - auto ptrTypeLLVM = LLVM::LLVMPointerType::get( - ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); + if (tdescTy) + ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); Value basePtrI64; if constexpr (std::is_same_v) { basePtrI64 = adaptor.getSource(); + if (auto memRefTy = dyn_cast(op.getSource().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } } else { basePtrI64 = adaptor.getDest(); + if (auto memRefTy = dyn_cast(op.getDest().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } } + if (basePtrI64.getType() != rewriter.getI64Type()) { + basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); + } + basePtrI64.dump(); Value offsets = adaptor.getOffsets(); + offsets.dump(); Value mask = adaptor.getMask(); + mask.dump(); if (offsets) { - VectorType offsetsVecTy = dyn_cast(offsets.getType()); - if (offsetsVecTy) { + if (dyn_cast(offsets.getType())){ // Offset needs be scalar. 
return rewriter.notifyMatchFailure(op, "Expected offsets to be a scalar."); @@ -526,8 +541,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op)); } } + basePtrI64.dump(); Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + basePtrLLVM.dump(); VectorType srcOrDstVecTy = op.getValueType(); VectorType srcOrDstFlatVecTy = VectorType::get( srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); @@ -597,6 +614,10 @@ class PrefetchToXeVMPattern : public OpConversionPattern { ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); Value basePtrI64 = adaptor.getSource(); Value offsets = adaptor.getOffsets(); + if (basePtrI64.getType() != rewriter.getI64Type()) { + basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); + } if (offsets) { VectorType offsetsVecTy = dyn_cast(offsets.getType()); if (offsetsVecTy) { @@ -836,6 +857,26 @@ struct ConvertXeGPUToXeVMPass auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); + typeConverter.addConversion([&](MemRefType type) -> Type { + // Convert MemRefType to i64 type. + return IntegerType::get(&getContext(), 64); + }); + + auto memrefMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (auto memrefTy = dyn_cast(input.getType())) { + + Value addr = memref::ExtractAlignedPointerAsIndexOp::create( + builder, loc, input); + return arith::IndexCastUIOp::create(builder, loc, type, + addr).getResult(); + } + return {}; + }; auto ui64MaterializationCast = [](OpBuilder &builder, Type type, ValueRange inputs, @@ -847,7 +888,22 @@ struct ConvertXeGPUToXeVMPass Value cast = index::CastUOp::create(builder, loc, builder.getIndexType(), input) .getResult(); - return arith::IndexCastOp::create(builder, loc, type, cast).getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult(); + } + return {}; + }; + + auto ui32MaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType() == builder.getIntegerType(32, false)) { + Value cast = + index::CastUOp::create(builder, loc, builder.getIndexType(), input) + .getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult(); } return {}; }; @@ -864,15 +920,19 @@ struct ConvertXeGPUToXeVMPass Value cast = vector::ExtractOp::create(builder, loc, input, 0).getResult(); if (vecTy.getElementType() == builder.getIndexType()) - cast = arith::IndexCastOp::create(builder, loc, type, cast) + cast = arith::IndexCastUIOp::create(builder, loc, type, cast) .getResult(); return cast; } } return {}; }; + typeConverter.addSourceMaterialization(memrefMaterializationCast); typeConverter.addSourceMaterialization(ui64MaterializationCast); + typeConverter.addSourceMaterialization(ui32MaterializationCast); typeConverter.addSourceMaterialization(vector1DMaterializationCast); + typeConverter.addTargetMaterialization(memrefMaterializationCast); + typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); typeConverter.addTargetMaterialization(vector1DMaterializationCast); ConversionTarget target(getContext()); diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir 
b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir index 4fba920f023c4..7f5e3527a1594 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -6,8 +6,8 @@ gpu.module @create_nd_tdesc { // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index, %stride1: index, %stride2: index) kernel { - // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32 diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index f6d023307313a..825a4d6368863 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -5,10 +5,10 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:.*]]: ui64 gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) { // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir new file mode 100644 index 0000000000000..a7ae4d9b7e4d2 --- /dev/null +++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s + +gpu.module @materializecast { + // CHECK-LABEL: gpu.func @materialize_memref + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> + gpu.func @materialize_memref(%src: memref<128xf32>) kernel { + // CHECK: XXX + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } + // CHECK-LABEL: gpu.func @materialize_ui64 + // CHECK-SAME: %[[ARG0:.*]]: ui64 + gpu.func @materialize_ui64(%src: ui64) kernel { + // CHECK: XXX + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } + // CHECK-LABEL: gpu.func @materialize_ui32 + // CHECK-SAME: %[[ARG0:.*]]: ui32 + gpu.func @materialize_ui32(%src: ui32) kernel { + %offset = arith.constant dense<0> : vector<1xindex> + //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex> + // -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } + // CHECK-LABEL: gpu.func @materialize_single_index_vector + // CHECK-SAME: 
%[[ARG0:.*]]: memref<128xf32> + gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel { + // CHECK: XXX + %offset = arith.constant dense<0> : vector<1xindex> + %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + gpu.return + } + // CHECK-LABEL: gpu.func @materialize_single_elem_vector + // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1> + gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel { + // CHECK: XXX + %mask = arith.constant dense<1>: vector<1xi1> + %offset = arith.constant dense<0> : vector<1xindex> + %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> + : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32> + gpu.return + } +} diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir index e9d7fd4cf40a6..6e59414c62582 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir @@ -4,12 +4,12 @@ gpu.module @update_offset { // CHECK-LABEL: gpu.func @update_offset // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> gpu.func @update_offset(%src: memref<128xf32>) kernel { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> %offset = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index - // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 From 8e507ec303c37524d793d0dcef382922a185257f Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 25 Aug 2025 18:32:56 +0000 Subject: [PATCH 06/18] Adjust to latest XeGPU dialect update. --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 198 ++++++++++++------ .../XeGPUToXeVM/loadstoreprefetch.mlir | 142 ++----------- .../XeGPUToXeVM/materializecast.mlir | 39 +++- 3 files changed, 185 insertions(+), 194 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 4ff5321c1d9d2..6cfa8ac1f8fce 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -446,9 +446,10 @@ class CreateDescToXeVMPattern matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto eTy = op.getTensorDescType().getElementType(); - if (eTy.getIntOrFloatBitWidth() % 8 != 0) { - return rewriter.notifyMatchFailure(op, - "Expected element type bit width to be multiple of 8."); + auto eBw = eTy.getIntOrFloatBitWidth(); + if (eBw % 8 != 0) { + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); } auto loc = op.getLoc(); // offsets are provided as scalar i64 by type converter. 
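+    // For illustration only, assuming f32 elements (4 bytes per element) and
+    // value names chosen for exposition, the per-lane address computed below
+    // lowers to roughly:
+    //   %c4      = arith.constant 4 : i64
+    //   %byteoff = arith.muli %offset, %c4 : i64
+    //   %lane    = arith.addi %base, %byteoff : i64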
@@ -458,10 +459,8 @@ class CreateDescToXeVMPattern Value addr = adaptor.getSource(); // ui32 or i32 are passed as i32 so they need to be casted to i64. if (addr.getType() != rewriter.getI64Type()) - addr = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI64Type(), addr); - auto laneAddr = - addOffset(rewriter, loc, addr, offsets, getElemByteSize(op)); + addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr); + auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8); rewriter.replaceOp(op, laneAddr); return success(); } @@ -475,16 +474,16 @@ class UpdateOffsetToXeVMPattern xegpu::UpdateOffsetOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto eTy = op.getTensorDescType().getElementType(); - if (eTy.getIntOrFloatBitWidth() % 8 != 0) { - return rewriter.notifyMatchFailure(op, - "Expected element type bit width to be multiple of 8."); + auto eBw = eTy.getIntOrFloatBitWidth(); + if (eBw % 8 != 0) { + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); } auto loc = op.getLoc(); // scatter descriptor is provided as scalar i64 by type converter. // offsets are provided as scalar i64 by type converter. - Value newOffset = - addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(), - getElemByteSize(op)); + Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(), + adaptor.getOffsets(), eBw / 8); rewriter.replaceOp(op, newOffset); return success(); } @@ -501,12 +500,35 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdescTy = op.getTensorDescType(); - LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( - ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); - if (tdescTy) - ptrTypeLLVM = LLVM::LLVMPointerType::get( - ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); Value basePtrI64; + // Load result or Store valye Type can be vector or scalar. + Type valOrResTy; + if constexpr (std::is_same_v) { + valOrResTy = op.getResult().getType(); + } else { + valOrResTy = adaptor.getValue().getType(); + } + VectorType valOrResVecTy = dyn_cast(valOrResTy); + bool hasScalarVal = !valOrResVecTy; + int64_t elemBitWidth = + hasScalarVal ? valOrResTy.getIntOrFloatBitWidth() + : valOrResVecTy.getElementType().getIntOrFloatBitWidth(); + // Element type must be multiple of 8 bits. + if (elemBitWidth % 8 != 0) { + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + } + int64_t elemByteSize = elemBitWidth / 8; + // Default memory space is global. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); + // If tensor descriptor is available, we use its memory space. + if (tdescTy) { + ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + } + // Base pointer can come from source (load) or dest (store). + // If they are memrefs, we use their memory space. if constexpr (std::is_same_v) { basePtrI64 = adaptor.getSource(); if (auto memRefTy = dyn_cast(op.getSource().getType())) { @@ -522,76 +544,79 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); } } + // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. 
if (basePtrI64.getType() != rewriter.getI64Type()) { - basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), - basePtrI64); + basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); } - basePtrI64.dump(); Value offsets = adaptor.getOffsets(); - offsets.dump(); Value mask = adaptor.getMask(); - mask.dump(); if (offsets) { - if (dyn_cast(offsets.getType())){ - // Offset needs be scalar. + if (dyn_cast(offsets.getType())) { + // Offset needs be scalar. Single element vector is converted to scalar + // by type converter. return rewriter.notifyMatchFailure(op, "Expected offsets to be a scalar."); } else { + // If offsets are provided, we add them to the base pointer. + // Offsets are in number of elements, we need to multiply by + // element byte size. basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op)); + addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); } } - basePtrI64.dump(); + // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); - basePtrLLVM.dump(); - VectorType srcOrDstVecTy = op.getValueType(); - VectorType srcOrDstFlatVecTy = VectorType::get( - srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); + Value maskForLane; VectorType maskVecTy = dyn_cast(mask.getType()); if (maskVecTy) { + // Mask needs be scalar. Single element vector is converted to scalar by + // type converter. return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar."); - } else + } else { maskForLane = mask; + } if constexpr (std::is_same_v) { - scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy}, + scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy}, maskForLane, true, true); + // If mask is true,- then clause - load from memory and yield. rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front()); + if (!hasScalarVal) + valOrResTy = VectorType::get({valOrResVecTy.getNumElements()}, + valOrResVecTy.getElementType()); Value loaded = - LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM); + LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM); + // Set cache control attribute on the load operation. loaded.getDefiningOp()->setAttr( "cache_control", xevm::LoadCacheControlAttr::get( ctxt, translateLoadXeGPUCacheHint( op.getL1Hint(), op.getL3Hint()))); - if (srcOrDstVecTy != srcOrDstFlatVecTy) { - loaded = - vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded); - } scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); - // If mask is false, we yield a vector of zeros. - auto eTy = srcOrDstVecTy.getElementType(); - loaded = arith::ConstantOp::create( - rewriter, loc, - eTy.isFloat() - ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0)) - : DenseElementsAttr::get(srcOrDstVecTy, - IntegerAttr::get(eTy, 0))); + // If mask is false - else clause -yield a vector of zeros. + auto eTy = hasScalarVal ? 
valOrResTy : valOrResVecTy.getElementType(); + TypedAttr eVal; + if (eTy.isFloat()) + eVal = FloatAttr::get(eTy, 0.0); + else + eVal = IntegerAttr::get(eTy, 0); + if (hasScalarVal) + loaded = arith::ConstantOp::create(rewriter, loc, eVal); + else + loaded = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal)); scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); rewriter.replaceOp(op, ifOp.getResult(0)); } else { + // if mask is true, perform the store. scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false); auto body = ifOp.getBody(); rewriter.setInsertionPointToStart(body); - VectorType valTy = op.getValue().getType(); - Value srcFlatVec = op.getValue(); - if (valTy != srcOrDstFlatVecTy) { - srcFlatVec = vector::ShapeCastOp::create(rewriter, loc, - srcOrDstFlatVecTy, srcFlatVec); - } auto storeOp = - LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM); + LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM); + // Set cache control attribute on the store operation. storeOp.getOperation()->setAttr( "cache_control", xevm::StoreCacheControlAttr::get( ctxt, translateStoreXeGPUCacheHint( @@ -610,14 +635,13 @@ class PrefetchToXeVMPattern : public OpConversionPattern { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdescTy = op.getTensorDescType(); - auto ptrTypeLLVM = LLVM::LLVMPointerType::get( - ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); Value basePtrI64 = adaptor.getSource(); - Value offsets = adaptor.getOffsets(); + // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. if (basePtrI64.getType() != rewriter.getI64Type()) { - basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), - basePtrI64); + basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); } + Value offsets = adaptor.getOffsets(); if (offsets) { VectorType offsetsVecTy = dyn_cast(offsets.getType()); if (offsetsVecTy) { @@ -625,12 +649,50 @@ class PrefetchToXeVMPattern : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Expected offsets to be a scalar."); } else { + int64_t elemBitWidth{0}; + int64_t elemByteSize; + // Element byte size can come from three sources: + if (tdescTy) { + // If tensor descriptor is available, we use its element type to + // determine element byte size. + elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); + } else if (auto memRefTy = dyn_cast(op.getSourceType())) { + // If memref is available, we use its element type to + // determine element byte size. + elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth(); + } else { + // Otherwise, we use the provided offset byte alignment. + elemByteSize = *op.getOffsetAlignByte(); + } + if (elemBitWidth != 0) { + if (elemBitWidth % 8 != 0) { + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + } + elemByteSize = elemBitWidth / 8; + } basePtrI64 = - addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op)); + addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); } } + // Default memory space is global. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); + // If tensor descriptor is available, we use its memory space. + if (tdescTy) { + ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + } + // If source is a memref, we use its memory space. 
+ if (auto memRefTy = dyn_cast(op.getSource().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } + // Convert base pointer (i64) to LLVM pointer type. Value ptrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + // Create the prefetch op with cache control attribute. xevm::PrefetchOp::create( rewriter, loc, ptrLLVM, xevm::LoadCacheControlAttr::get( @@ -863,17 +925,17 @@ struct ConvertXeGPUToXeVMPass }); auto memrefMaterializationCast = [](OpBuilder &builder, Type type, - ValueRange inputs, - Location loc) -> Value { + ValueRange inputs, + Location loc) -> Value { if (inputs.size() != 1) return {}; auto input = inputs.front(); if (auto memrefTy = dyn_cast(input.getType())) { - Value addr = memref::ExtractAlignedPointerAsIndexOp::create( - builder, loc, input); - return arith::IndexCastUIOp::create(builder, loc, type, - addr).getResult(); + Value addr = + memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input); + return arith::IndexCastUIOp::create(builder, loc, type, addr) + .getResult(); } return {}; }; @@ -888,7 +950,8 @@ struct ConvertXeGPUToXeVMPass Value cast = index::CastUOp::create(builder, loc, builder.getIndexType(), input) .getResult(); - return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); } return {}; }; @@ -903,7 +966,8 @@ struct ConvertXeGPUToXeVMPass Value cast = index::CastUOp::create(builder, loc, builder.getIndexType(), input) .getResult(); - return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); } return {}; }; diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 825a4d6368863..0f67dc290689b 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -37,15 +37,15 @@ gpu.module @test { // CHECK-LABEL: @load_gather_memref_src_constant_offset // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 @@ -73,12 +73,12 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex> gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: 
vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 @@ -99,51 +99,15 @@ gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: ve } // ----- -gpu.module @test { -// CHECK-LABEL: @load_gather_memref_src_load_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex> -gpu.func @load_gather_memref_src_load_offset(%src: memref<256xf16>, %offset1: vector<1xindex>, %offset2: vector<1xindex>) { - // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 - // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index - // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 - // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C2_I64]] : i64 - // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64 - %2 = xegpu.create_tdesc %src, %offset1 : memref<256xf16>, vector<1xindex> - -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr> - // CHECK: %[[C2_I64_0:.*]] = arith.constant 2 : i64 - // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C2_I64_0]] : i64 - // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64 - // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1> - // CHECK: %[[VAR11:.*]] = scf.if %[[VAR4]] -> (vector<8xf16>) { - // CHECK: %[[VAR12:.*]] = llvm.load %[[VAR10]] {cache_control = #xevm.load_cache_control} - // CHECK-SAME: : !llvm.ptr<1> -> vector<8xf16> - // CHECK: scf.yield %[[VAR12]] : vector<8xf16> - // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16> - // CHECK: scf.yield %[[CST_0]] : vector<8xf16> - %3 = xegpu.load %2[%offset2], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr>, vector<1xindex>, vector<1xi1> -> vector<8xf16> - gpu.return -} -} -// ----- - gpu.module @test { // CHECK-LABEL: @store_scatter_ui64_src_constant_offset // CHECK-SAME: %[[ARG0:.*]]: ui64 gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) { // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : 
index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> @@ -170,17 +134,17 @@ gpu.module @test { // CHECK-LABEL: @store_scatter_memref_src_constant_offset // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[CST_0:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16> %2 = arith.constant dense<2.9>: vector<2xf16> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64 // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 @@ -202,14 +166,15 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %1 = arith.constant dense<1>: vector<1xi1> // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> + // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32> %2 = arith.constant dense<2.9>: vector<1xf32> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64 @@ -217,8 +182,8 @@ gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1> // CHECK: scf.if 
%[[VAR2]] { - // CHECK: llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control} - // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1> + // CHECK: llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control} + // CHECK-SAME: : f32, !llvm.ptr<1> xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> gpu.return @@ -226,49 +191,15 @@ gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: } // ----- -gpu.module @test { -// CHECK-LABEL: @store_scatter_memref_src_store_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex> -gpu.func @store_scatter_memref_src_store_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) { - // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 - // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 - // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> - // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> - %1 = arith.constant dense<1>: vector<1xi1> - // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32> - %2 = arith.constant dense<2.9>: vector<1xf32> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64 - %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> - // CHECK: %[[C4_I64_1:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C4_I64_1]] : i64 - // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64 - // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1> - // CHECK: scf.if %[[VAR4]] { - // CHECK: llvm.store %[[CST_0]], %[[VAR10]] {cache_control = #xevm.store_cache_control} - // CHECK-SAME: : vector<1xf32>, !llvm.ptr<1> - xegpu.store %2, %3[%offset2], %1 <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1> - gpu.return -} -} -// ----- - gpu.module @test { // CHECK-LABEL: @prefetch_ui64_src_constant_offset // CHECK-SAME: %[[ARG0:.*]]: ui64 gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) { // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 + // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64 %0 = arith.constant dense<0> : vector<1xindex> // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 @@ -288,12 +219,12 @@ gpu.module @test { // CHECK-LABEL: 
@prefetch_memref_src_constant_offset // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32> gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) { + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 %0 = arith.constant dense<0> : vector<1xindex> - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64 // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64 @@ -313,7 +244,7 @@ gpu.module @test { // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex> gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) { // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 @@ -328,30 +259,3 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto gpu.return } } -// ----- - -gpu.module @test { -// CHECK-LABEL: @prefetch_memref_src_prefetch_offset -// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex> -gpu.func @prefetch_memref_src_prefetch_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) { - // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex> - // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64 - // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex> - // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64 - // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index - // CHECK: %[[VAR4:.*]] = arith.index_castui %[[INTPTR]] : index to i64 - // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64 - // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64 - %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex> - -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr> - // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64 - // CHECK: %[[VAR7:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64 - // CHECK: %[[VAR8:.*]] = arith.addi %[[VAR6]], %[[VAR7]] : i64 - // CHECK: %[[VAR9:.*]] = llvm.inttoptr %[[VAR8]] : i64 to !llvm.ptr<1> - // CHECK: xevm.prefetch %[[VAR9]] <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) - xegpu.prefetch %1[%offset2] <{l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr>, vector<1xindex> - gpu.return -} -} diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir 
b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir index a7ae4d9b7e4d2..8db0843de4cc1 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir @@ -1,45 +1,68 @@ -// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s +// RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s gpu.module @materializecast { // CHECK-LABEL: gpu.func @materialize_memref // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> gpu.func @materialize_memref(%src: memref<128xf32>) kernel { - // CHECK: XXX + // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index + // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64 %offset = arith.constant dense<0> : vector<1xindex> %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> gpu.return } +} + +// ----- +gpu.module @materializecast { // CHECK-LABEL: gpu.func @materialize_ui64 // CHECK-SAME: %[[ARG0:.*]]: ui64 gpu.func @materialize_ui64(%src: ui64) kernel { - // CHECK: XXX + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64 %offset = arith.constant dense<0> : vector<1xindex> %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex> -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> gpu.return } +} + +// ----- +gpu.module @materializecast { // CHECK-LABEL: gpu.func @materialize_ui32 // CHECK-SAME: %[[ARG0:.*]]: ui32 gpu.func @materialize_ui32(%src: ui32) kernel { + // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index + // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32 %offset = arith.constant dense<0> : vector<1xindex> - //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex> - // -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> + %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex> + -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> gpu.return } +} + +// ----- +gpu.module @materializecast { // CHECK-LABEL: gpu.func @materialize_single_index_vector // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel { - // CHECK: XXX + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex> + // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex> + // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64 %offset = arith.constant dense<0> : vector<1xindex> %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex> -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>> gpu.return } +} + +// ----- +gpu.module @materializecast { // CHECK-LABEL: gpu.func @materialize_single_elem_vector - // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1> + // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32> gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel { - // CHECK: XXX + // CHECK: %[[CST:.*]] = arith.constant dense : vector<1xi1> + // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1> %mask = arith.constant dense<1>: vector<1xi1> %offset = arith.constant dense<0> : vector<1xindex> %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> From d372592306858cc17a2cac6397e4755f5cdc826a Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 25 Aug 2025 21:11:26 +0000 Subject: [PATCH 
07/18] Temp save. --- .../mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h | 8 +- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 119 ++++++++---------- .../XeGPUToXeVM/create_nd_tdesc.mlir | 12 +- 3 files changed, 62 insertions(+), 77 deletions(-) diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h index fb23d24b0161b..ddaaae82e03be 100644 --- a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h +++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h @@ -5,8 +5,8 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_ -#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_ +#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_ +#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_ #include @@ -20,8 +20,8 @@ class Pass; #include "mlir/Conversion/Passes.h.inc" void populateXeGPUToXeVMConversionPatterns( - mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter); + const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns); } // namespace mlir -#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_ +#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_ diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 6cfa8ac1f8fce..19324f748bded 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -57,7 +57,8 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { llvm_unreachable("Unknown XeGPU memory space."); } -VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) { +static VectorType encodeVectorTypeTo(VectorType currentVecType, + Type toElemType) { auto elemType = currentVecType.getElementType(); auto currentBitWidth = elemType.getIntOrFloatBitWidth(); auto newBitWidth = toElemType.getIntOrFloatBitWidth(); @@ -66,13 +67,11 @@ VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) { return VectorType::get(size, toElemType); } -xevm::LoadCacheControl +static xevm::LoadCacheControl translateLoadXeGPUCacheHint(std::optional L1hint, std::optional L3hint) { - auto L1hintVal = - L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED; - auto L3hintVal = - L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED; + auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED); + auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED); switch (L1hintVal) { case xegpu::CachePolicy::CACHED: if (L3hintVal == xegpu::CachePolicy::CACHED) @@ -102,13 +101,11 @@ translateLoadXeGPUCacheHint(std::optional L1hint, } } -xevm::StoreCacheControl +static xevm::StoreCacheControl translateStoreXeGPUCacheHint(std::optional L1hint, std::optional L3hint) { - auto L1hintVal = - L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED; - auto L3hintVal = - L3hint.has_value() ? 
L3hint.value() : xegpu::CachePolicy::UNCACHED; + auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED); + auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED); switch (L1hintVal) { case xegpu::CachePolicy::UNCACHED: if (L3hintVal == xegpu::CachePolicy::UNCACHED) @@ -152,10 +149,14 @@ class CreateNdDescToXeVMPattern ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto source = op.getSource(); + // Op is lowered to a code sequence that populates payload. + // payload is a 8xi32 vector. Type payloadElemTy = rewriter.getI32Type(); Type i64Ty = rewriter.getI64Type(); VectorType payloadTy = VectorType::get(8, payloadElemTy); + // 4xi64 view is used for inserting the base pointer. VectorType payloadI64Ty = VectorType::get(4, i64Ty); + // Initialize payload to zero. Value payload = arith::ConstantOp::create( rewriter, loc, DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0))); @@ -166,73 +167,56 @@ class CreateNdDescToXeVMPattern Value offsetW; Value offsetH; - bool sourceIsMemref = false; + // Source can be a memref or a pointer (ui64, ui32, i64 or i32). + SmallVector mixedSizes = op.getMixedSizes(); + SmallVector mixedOffsets = op.getMixedOffsets(); + // Descriptor shape is expected to be 2D. + int64_t rank = mixedSizes.size(); + if (rank != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D shape."); auto sourceTy = source.getType(); - int64_t rank; - if (isa(sourceTy)) { - sourceIsMemref = true; + auto sourceMemrefTy = dyn_cast(sourceTy); + // If source is a memref, we need to extract the aligned pointer as index. + // pointer type is passed as i32 or i64 by type converter. + if (sourceMemrefTy) { baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); - auto sourceMemrefTy = cast(sourceTy); if (!sourceMemrefTy.hasStaticShape()) { op.emitError() << "Expected static memref shape."; return failure(); } - rank = sourceMemrefTy.getRank(); - if (rank != 2) { - op.emitError() << "Expected a 2D memref."; - return failure(); - } - } else if (sourceTy == rewriter.getIntegerType(64, false)) { - rank = op.getMixedSizes().size(); } else { - op.emitError() << "Expected source to be a 2D memref or ui64."; - return failure(); + baseAddr = adaptor.getSource(); } - auto createOffset = [&](unsigned idx) -> Value { - Value val; - OpFoldResult ofr = op.getMixedOffsets()[idx]; - if (auto v = llvm::dyn_cast_if_present(ofr)) { - val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v); - val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val); - } else { - int32_t off = llvm::cast(cast(ofr)).getInt(); - val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off); - } + // utility for creating offset values from op fold result. + auto createOffset = [&](SmallVector &ofrVec, + unsigned idx) -> Value { + Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]); + val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); return val; }; - auto offsets = op.getMixedOffsets(); - if (offsets.size() == 2) { - offsetW = createOffset(rank - 1); - offsetH = createOffset(rank - 2); - } else { + // Offsets can be either 2D or not provided (0 is used). 
+ if (mixedOffsets.size() == 2) { + offsetW = createOffset(mixedOffsets, rank - 1); + offsetH = createOffset(mixedOffsets, rank - 2); + } else if (mixedOffsets.size() == 0) { offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); + } else { + return rewriter.notifyMatchFailure(op, + "Expected 2D offsets or no offsets."); } - auto createShape = [&](unsigned idx) -> Value { - Value val; - OpFoldResult ofr = op.getMixedSizes()[idx]; - if (auto v = llvm::dyn_cast_if_present(ofr)) { - val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v); - val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val); - } else { - int32_t off = llvm::cast(cast(ofr)).getInt(); - val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off); - } - return val; - }; - if (sourceIsMemref) { - auto sourceMemrefTy = cast(sourceTy); - baseShapeW = arith::ConstantIntOp::create( - rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1)); - baseShapeH = arith::ConstantIntOp::create( - rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2)); + // Get shape values from op fold results. + baseShapeW = createOffset(mixedSizes, rank - 1); + baseShapeH = createOffset(mixedSizes, rank - 2); + if (sourceMemrefTy) { + // cast index to i64. baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); - } else { - baseShapeW = createShape(rank - 1); - baseShapeH = createShape(rank - 2); - baseAddr = adaptor.getSource(); + } else if (baseAddr.getType() != i64Ty) { + // pointer type may be i32. Cast to i64 if needed. + baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); } + // Populate payload. Value payLoadAsI64 = vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); payLoadAsI64 = @@ -429,9 +413,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { // Add a builder that creates // offset * elemByteSize + baseAddr -auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, - int64_t elemByteSize) -> Value { +static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc, + Value baseAddr, Value offset, + int64_t elemByteSize) -> Value { Value byteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI64Type(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); @@ -701,6 +685,7 @@ class PrefetchToXeVMPattern : public OpConversionPattern { return success(); } }; + class FenceToXeVMPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -1007,7 +992,7 @@ struct ConvertXeGPUToXeVMPass target.addIllegalDialect(); RewritePatternSet patterns(&getContext()); - populateXeGPUToXeVMConversionPatterns(patterns, typeConverter); + populateXeGPUToXeVMConversionPatterns(typeConverter, patterns); scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns, target); if (failed(applyPartialConversion(getOperation(), target, @@ -1021,7 +1006,7 @@ struct ConvertXeGPUToXeVMPass // Pattern Population //===----------------------------------------------------------------------===// void mlir::populateXeGPUToXeVMConversionPatterns( - RewritePatternSet &patterns, LLVMTypeConverter &typeConverter) { + const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add, LoadStorePrefetchNdToXeVMPattern, diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir 
b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir index 7f5e3527a1594..ba7ece8ccbebe 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir @@ -11,10 +11,8 @@ gpu.module @create_nd_tdesc { // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32> // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32 // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32 - // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64 - // CHECK: %[[VAR3:.*]] = arith.trunci %[[VAR2]] : i64 to i32 - // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG2]] : index to i64 - // CHECK: %[[VAR5:.*]] = arith.trunci %[[VAR4]] : i64 to i32 + // CHECK: %[[VAR3:.*]] = arith.index_cast %[[ARG3]] : index to i32 + // CHECK: %[[VAR5:.*]] = arith.index_cast %[[ARG2]] : index to i32 // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64> // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64> // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32> @@ -32,8 +30,10 @@ gpu.module @create_nd_tdesc { // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32 // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32 - // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32 - // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32 + // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64 + // CHECK: %[[C16_I32:.*]] = arith.trunci %c16_i64 : i64 to i32 + // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64 + // CHECK: %[[C8_I32:.*]] = arith.trunci %c8_i64 : i64 to i32 // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64 // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64> // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64> From d88d676129de7c10f4eebae4e1bc5d777ea91d21 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 25 Aug 2025 22:11:43 +0000 Subject: [PATCH 08/18] Update update_nd_tdesc. --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 19324f748bded..12af0af70177b 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -248,27 +248,26 @@ class UpdateNdOffsetToXeVMPattern xegpu::UpdateNdOffsetOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto offsets = op.getOffsets(); + auto mixedOffsets = op.getMixedOffsets(); + if (mixedOffsets.size() != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto tdesc = adaptor.getTensorDesc(); - for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) { - auto offset = offsets[offsetDim]; - if (auto cst = - dyn_cast_if_present(offset.getDefiningOp())) - if (auto attr = dyn_cast_if_present(cst.getValue()); - attr && !attr.getInt()) - continue; - const int offsetPos = - static_cast(offsetDim ? 
NdDescI32Layout::TensorOffsetW - : NdDescI32Layout::TensorOffsetH); - auto oldOffset = - vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos); - offset = arith::IndexCastUIOp::create(rewriter, loc, - rewriter.getI32Type(), offset); - auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); - tdesc = - vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos); - } - rewriter.replaceOp(op, tdesc); + // utility for updating payload offset values from op fold result. + auto updateOffset = [&](unsigned idx, int payloadPos) -> Value { + Value offset = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]); + offset = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offset); + Value oldOffset = + vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos); + Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); + return vector::InsertOp::create(rewriter, loc, newOffset, tdesc, + payloadPos); + }; + auto val = + updateOffset(0, static_cast(NdDescI32Layout::TensorOffsetH)); + val = updateOffset(1, static_cast(NdDescI32Layout::TensorOffsetW)); + rewriter.replaceOp(op, val); return success(); } }; From 236343e189f8ebea578ce872a0f2206a31ab1536 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Mon, 25 Aug 2025 22:29:06 +0000 Subject: [PATCH 09/18] Temp save. --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 160 ++++++++++-------- .../Conversion/XeGPUToXeVM/loadstore_nd.mlir | 12 +- .../XeGPUToXeVM/materializecast.mlir | 2 +- .../Conversion/XeGPUToXeVM/prefetch_nd.mlir | 6 +- 4 files changed, 99 insertions(+), 81 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 12af0af70177b..963ab29695b1f 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -39,12 +39,13 @@ using namespace mlir; namespace { -enum class NdDescI32Layout : uint32_t { - BasePtr = 0, - BaseShapeW = 2, - BaseShapeH = 3, - TensorOffsetW = 4, - TensorOffsetH = 5 +// Offsets to individual fields of the 8xi32 layout nd tensor descriptor. +enum class NdTdescOffset : uint32_t { + BasePtr = 0, // Base pointer (i64) + BaseShapeW = 2, // Base shape width (i32) + BaseShapeH = 3, // Base shape height (i32) + TensorOffsetW = 4, // Tensor offset W (i32) + TensorOffsetH = 5 // Tensor offset H (i32) }; static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { @@ -57,6 +58,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { llvm_unreachable("Unknown XeGPU memory space."); } +// Get same bitwidth flat vector type of new element type. 
static VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) { auto elemType = currentVecType.getElementType(); @@ -221,20 +223,20 @@ class CreateNdDescToXeVMPattern vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); payLoadAsI64 = vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64, - static_cast(NdDescI32Layout::BasePtr)); + static_cast(NdTdescOffset::BasePtr)); payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64); payload = vector::InsertOp::create(rewriter, loc, baseShapeW, payload, - static_cast(NdDescI32Layout::BaseShapeW)); + static_cast(NdTdescOffset::BaseShapeW)); payload = vector::InsertOp::create(rewriter, loc, baseShapeH, payload, - static_cast(NdDescI32Layout::BaseShapeH)); + static_cast(NdTdescOffset::BaseShapeH)); payload = vector::InsertOp::create( rewriter, loc, offsetW, payload, - static_cast(NdDescI32Layout::TensorOffsetW)); + static_cast(NdTdescOffset::TensorOffsetW)); payload = vector::InsertOp::create( rewriter, loc, offsetH, payload, - static_cast(NdDescI32Layout::TensorOffsetH)); + static_cast(NdTdescOffset::TensorOffsetH)); rewriter.replaceOp(op, payload); return success(); } @@ -249,6 +251,7 @@ class UpdateNdOffsetToXeVMPattern ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto mixedOffsets = op.getMixedOffsets(); + // Only 2D offsets are supported for now. if (mixedOffsets.size() != 2) return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto tdesc = adaptor.getTensorDesc(); @@ -264,9 +267,9 @@ class UpdateNdOffsetToXeVMPattern return vector::InsertOp::create(rewriter, loc, newOffset, tdesc, payloadPos); }; - auto val = - updateOffset(0, static_cast(NdDescI32Layout::TensorOffsetH)); - val = updateOffset(1, static_cast(NdDescI32Layout::TensorOffsetW)); + // Update offsets in the payload. + auto val = updateOffset(0, static_cast(NdTdescOffset::TensorOffsetH)); + val = updateOffset(1, static_cast(NdTdescOffset::TensorOffsetW)); rewriter.replaceOp(op, val); return success(); } @@ -293,62 +296,46 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); Value payLoadAsI64 = vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); - Value basePtr = - vector::ExtractOp::create(rewriter, loc, payLoadAsI64, - static_cast(NdDescI32Layout::BasePtr)); + Value basePtr = vector::ExtractOp::create( + rewriter, loc, payLoadAsI64, static_cast(NdTdescOffset::BasePtr)); Value baseShapeW = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast(NdDescI32Layout::BaseShapeW)); + rewriter, loc, tdesc, static_cast(NdTdescOffset::BaseShapeW)); Value baseShapeH = vector::ExtractOp::create( - rewriter, loc, tdesc, static_cast(NdDescI32Layout::BaseShapeH)); - // Offsets can come from three sources: - // 1. Constant offsets, which are provided by the op. - // 2. Offsets as operands, which are provided by the op. - // 3. Offsets extracted from the tensor descriptor. + rewriter, loc, tdesc, static_cast(NdTdescOffset::BaseShapeH)); + // Offsets provided in two ways: + // 1. Offsets are extracted from the tensor descriptor. + // 2. (Mixed) offsets which are provided by the op. 
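  // Note on the "mixed" offsets used below: each entry is an OpFoldResult,
  // i.e. either a constant Attribute or an SSA Value. The two arith utility
  // helpers used here do the normalization: getValueOrCreateConstantIntOp
  // materializes the constant case as an arith.constant (or passes a Value
  // through), and getValueOrCreateCastToIndexLike inserts the index/int cast
  // needed to reach the requested i32 type.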
Value offsetW; Value offsetH; - auto cOffsets = op.getConstOffsets(); - auto offsets = op.getOffsets(); - if (cOffsets) { - offsetW = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]); - offsetH = arith::ConstantIntOp::create( - rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]); - } else if (offsets.size() != 0) { - // offsets are provided as operands - if (offsets[0].getType() != rewriter.getI32Type()) { - if (offsets[0].getType() != rewriter.getIndexType()) { - return rewriter.notifyMatchFailure( - op, "Expected offsets to be of type i32 or index."); - } - offsetW = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), offsets[0]); - } else { - offsetW = offsets[0]; - } - if (offsets[1].getType() != rewriter.getI32Type()) { - if (offsets[1].getType() != rewriter.getIndexType()) { - return rewriter.notifyMatchFailure( - op, "Expected offsets to be of type i32 or index."); - } - offsetH = arith::IndexCastUIOp::create( - rewriter, loc, rewriter.getI32Type(), offsets[1]); - } else { - offsetH = offsets[1]; - } + auto mixedOffsets = op.getMixedOffsets(); + int64_t opOffsetsSize = mixedOffsets.size(); + if (opOffsetsSize != 0 && opOffsetsSize != 2) { + return rewriter.notifyMatchFailure(op, + "Expected 2D offsets or no offsets."); + } + if (opOffsetsSize) { + // If mixed offsets are provided by the op convert them to i32. + offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); } else { // If offsets are not available, we need to extract them from the tensor // descriptor. offsetW = vector::ExtractOp::create( - rewriter, loc, tdesc, - static_cast(NdDescI32Layout::TensorOffsetW)); + rewriter, loc, tdesc, static_cast(NdTdescOffset::TensorOffsetW)); offsetH = vector::ExtractOp::create( - rewriter, loc, tdesc, - static_cast(NdDescI32Layout::TensorOffsetH)); + rewriter, loc, tdesc, static_cast(NdTdescOffset::TensorOffsetH)); } + // Get address space from tensor descriptor memory space. auto ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + // Compute element byte size and surface width in bytes. auto elemType = tdescTy.getElementType(); auto elemBitSize = elemType.getIntOrFloatBitWidth(); Value elemByteSize = arith::ConstantIntOp::create( @@ -356,23 +343,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { Value surfaceW = arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + // Get tile sizes and vblocks from the tensor descriptor type. 
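  // For illustration only: the XeVM 2D block ops take the surface width and
  // pitch in bytes (the same surfaceW value is passed for both below), while
  // tile width/height stay in elements. A standalone sketch of that unit
  // conversion, with an illustrative name not taken from the patch:
  static int64_t surfaceWidthInBytes(int64_t baseShapeW, int64_t elemBitSize) {
    return baseShapeW * (elemBitSize / 8); // e.g. a 16-wide f32 row -> 64 bytes
  }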
auto tileW = tdescTy.getDimSize(1); auto tileH = tdescTy.getDimSize(0); int32_t vblocks = tdescTy.getArrayLength(); if constexpr (std::is_same_v) { - VectorType srcVecTy = cast(op.getValue().getType()); + VectorType srcVecTy = dyn_cast(adaptor.getValue().getType()); + if (!srcVecTy) { + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + } auto storeCacheControl = translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); - VectorType srcFlatVecTy = - VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType()); - Value srcFlatVec = op.getValue(); - srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy, - rewriter.getIntegerType(elemBitSize)); - srcFlatVec = - vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec); + Value src = adaptor.getValue(); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); xevm::BlockStore2dOp::create( rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, - offsetH, elemBitSize, tileW, tileH, srcFlatVec, + offsetH, elemBitSize, tileW, tileH, src, xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); rewriter.eraseOp(op); } else { @@ -412,15 +403,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { // Add a builder that creates // offset * elemByteSize + baseAddr -static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc, - Value baseAddr, Value offset, - int64_t elemByteSize) -> Value { +static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, + Value baseAddr, Value offset, int64_t elemByteSize) { Value byteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI64Type(), elemByteSize); Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); return newAddr; -}; +} class CreateDescToXeVMPattern : public OpConversionPattern { @@ -908,6 +898,10 @@ struct ConvertXeGPUToXeVMPass return IntegerType::get(&getContext(), 64); }); + // LLVM type converter puts unrealized casts for the following cases: + // add materialization casts to handle them. 
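    // Summary of the materializations registered below (each callback returns
    // a Value when it applies and {} otherwise):
    //  - memref source: extract the aligned pointer as index, then
    //    index_castui to the converted integer type.
    //  - ui64 / ui32 source: index.castu to index, then index_castui to
    //    i64 / i32.
    //  - vector source: extract the scalar from a single-element 1D vector
    //    (with an index cast if needed), bitcast between same-rank vectors,
    //    or shape_cast between vectors of different rank but same element
    //    type.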
+ + // Materialization to convert memref to i64 auto memrefMaterializationCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { @@ -924,6 +918,7 @@ struct ConvertXeGPUToXeVMPass return {}; }; + // Materialization to convert ui64 to i64 auto ui64MaterializationCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { @@ -940,6 +935,7 @@ struct ConvertXeGPUToXeVMPass return {}; }; + // Materialization to convert ui32 to i32 auto ui32MaterializationCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { @@ -956,9 +952,13 @@ struct ConvertXeGPUToXeVMPass return {}; }; - auto vector1DMaterializationCast = [](OpBuilder &builder, Type type, - ValueRange inputs, - Location loc) -> Value { + // Materialization to convert + // - single element 1D vector to scalar + // - bitcast vector of same rank + // - shape vector of different rank but same element type + auto vectorMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { if (inputs.size() != 1) return {}; auto input = inputs.front(); @@ -971,6 +971,18 @@ struct ConvertXeGPUToXeVMPass cast = arith::IndexCastUIOp::create(builder, loc, type, cast) .getResult(); return cast; + } else if (auto targetVecTy = dyn_cast(type)) { + // If the target type is a vector of same rank, + // bitcast to the target type. + if (targetVecTy.getRank() == vecTy.getRank()) + return vector::BitCastOp::create(builder, loc, targetVecTy, input) + .getResult(); + else if (targetVecTy.getElementType() == vecTy.getElementType()) { + // If the target type is a vector of different rank but same element + // type, reshape to the target type. + return vector::ShapeCastOp::create(builder, loc, targetVecTy, input) + .getResult(); + } } } return {}; @@ -978,11 +990,11 @@ struct ConvertXeGPUToXeVMPass typeConverter.addSourceMaterialization(memrefMaterializationCast); typeConverter.addSourceMaterialization(ui64MaterializationCast); typeConverter.addSourceMaterialization(ui32MaterializationCast); - typeConverter.addSourceMaterialization(vector1DMaterializationCast); + typeConverter.addSourceMaterialization(vectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); - typeConverter.addTargetMaterialization(vector1DMaterializationCast); + typeConverter.addTargetMaterialization(vectorMaterializationCast); ConversionTarget target(getContext()); target.addLegalDialect //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32 - //CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32 + //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64 + //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32 + //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64 + //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32 //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1> //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32 //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32 @@ -54,8 +56,10 @@ gpu.module @load_store_check { //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64> //CHECK: %[[BASE_W:.*]] = vector.extract 
%[[DESC]][2] : i32 from vector<8xi32> //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32 - //CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32 + //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64 + //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32 + //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64 + //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32 //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1> //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32 //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32 diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir index 8db0843de4cc1..2445c4b341657 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir @@ -66,7 +66,7 @@ gpu.module @materializecast { %mask = arith.constant dense<1>: vector<1xi1> %offset = arith.constant dense<0> : vector<1xindex> %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint, l2_hint = #xegpu.cache_hint}> - : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32> + : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32> gpu.return } } diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir index 8513b4f9857fb..873478aed57e3 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir @@ -20,8 +20,10 @@ gpu.module @fence_check { //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64> //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32> //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32> - //CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32 - //CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32 + //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64 + //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32 + //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64 + //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32 //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1> //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32 //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32 From a8cd5e08cfc536fb27f41f23820b21d0883afd8f Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 26 Aug 2025 00:49:47 +0000 Subject: [PATCH 10/18] Address comments. --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 963ab29695b1f..6cd50a38a21a4 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -152,7 +152,7 @@ class CreateNdDescToXeVMPattern auto loc = op.getLoc(); auto source = op.getSource(); // Op is lowered to a code sequence that populates payload. - // payload is a 8xi32 vector. + // Payload is a 8xi32 vector. 
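    // For illustration only: a standalone sketch of the payload lane layout,
    // mirroring the NdTdescOffset enum (helper name is illustrative and not
    // part of the patch; placing the 64-bit base pointer in lanes 0-1 mirrors
    // the vector<4xi64> bitcast view and assumes a little-endian target).
    #include <array>
    #include <cstdint>
    #include <cstring>
    static std::array<int32_t, 8> packNdPayload(uint64_t baseAddr,
                                                int32_t shapeW, int32_t shapeH,
                                                int32_t offsetW,
                                                int32_t offsetH) {
      std::array<int32_t, 8> payload{}; // all lanes zero-initialized
      std::memcpy(payload.data(), &baseAddr, sizeof(baseAddr)); // lanes 0-1: BasePtr
      payload[2] = shapeW;  // BaseShapeW
      payload[3] = shapeH;  // BaseShapeH
      payload[4] = offsetW; // TensorOffsetW
      payload[5] = offsetH; // TensorOffsetH
      return payload;
    }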
Type payloadElemTy = rewriter.getI32Type(); Type i64Ty = rewriter.getI64Type(); VectorType payloadTy = VectorType::get(8, payloadElemTy); @@ -179,7 +179,7 @@ class CreateNdDescToXeVMPattern auto sourceTy = source.getType(); auto sourceMemrefTy = dyn_cast(sourceTy); // If source is a memref, we need to extract the aligned pointer as index. - // pointer type is passed as i32 or i64 by type converter. + // Pointer type is passed as i32 or i64 by type converter. if (sourceMemrefTy) { baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); @@ -190,7 +190,7 @@ class CreateNdDescToXeVMPattern } else { baseAddr = adaptor.getSource(); } - // utility for creating offset values from op fold result. + // Utility for creating offset values from op fold result. auto createOffset = [&](SmallVector &ofrVec, unsigned idx) -> Value { Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]); @@ -212,10 +212,10 @@ class CreateNdDescToXeVMPattern baseShapeW = createOffset(mixedSizes, rank - 1); baseShapeH = createOffset(mixedSizes, rank - 2); if (sourceMemrefTy) { - // cast index to i64. + // Cast index to i64. baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); } else if (baseAddr.getType() != i64Ty) { - // pointer type may be i32. Cast to i64 if needed. + // Pointer type may be i32. Cast to i64 if needed. baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); } // Populate payload. @@ -255,7 +255,7 @@ class UpdateNdOffsetToXeVMPattern if (mixedOffsets.size() != 2) return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); auto tdesc = adaptor.getTensorDesc(); - // utility for updating payload offset values from op fold result. + // Utility for updating payload offset values from op fold result. auto updateOffset = [&](unsigned idx, int payloadPos) -> Value { Value offset = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]); @@ -425,7 +425,7 @@ class CreateDescToXeVMPattern op, "Expected element type bit width to be multiple of 8."); } auto loc = op.getLoc(); - // offsets are provided as scalar i64 by type converter. + // Offsets are provided as scalar i64 by type converter. auto offsets = adaptor.getOffsets(); // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32). // But type converter will convert them to integer types. @@ -453,8 +453,8 @@ class UpdateOffsetToXeVMPattern op, "Expected element type bit width to be multiple of 8."); } auto loc = op.getLoc(); - // scatter descriptor is provided as scalar i64 by type converter. - // offsets are provided as scalar i64 by type converter. + // Scatter descriptor is provided as scalar i64 by type converter. + // Offsets are provided as scalar i64 by type converter. Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(), eBw / 8); rewriter.replaceOp(op, newOffset); @@ -583,7 +583,7 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { scf::YieldOp::create(rewriter, loc, ValueRange{loaded}); rewriter.replaceOp(op, ifOp.getResult(0)); } else { - // if mask is true, perform the store. + // If mask is true, perform the store. 
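      // For illustration only: the guarded store produced here behaves like a
      // plain per-lane conditional store. A minimal scalar sketch with
      // illustrative names, not part of the patch:
      static void maskedStore(float *ptr, float value, bool mask) {
        if (mask)       // lowers to: scf.if %mask { llvm.store %value, %ptr }
          *ptr = value; // skipped entirely when the lane is masked off
      }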
scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false); auto body = ifOp.getBody(); rewriter.setInsertionPointToStart(body); @@ -758,7 +758,7 @@ class DpasToXeVMPattern : public OpConversionPattern { VectorType::get(cvecty.getNumElements(), cvecty.getElementType()); if (cvecty != cNty) c = vector::ShapeCastOp::create(rewriter, loc, cNty, c); - // below are uArch dependent values, should move away from hardcoding + // Below are uArch dependent values, should move away from hardcoding constexpr int32_t systolicDepth{8}; constexpr int32_t executionSize{16}; Value dpasRes = xevm::MMAOp::create( @@ -818,7 +818,6 @@ matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) { default: return std::nullopt; } - llvm_unreachable("Invalid AtomicRMWKind"); } class AtomicRMWToXeVMPattern : public OpConversionPattern { From 953b8508152bab32cd905cd582c29c5340adcf2f Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 26 Aug 2025 01:00:46 +0000 Subject: [PATCH 11/18] Remove unneeded llvm_unreachable. --- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 6cd50a38a21a4..172b09bacdc03 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -55,7 +55,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { case xegpu::MemorySpace::SLM: return static_cast(xevm::AddrSpace::SHARED); } - llvm_unreachable("Unknown XeGPU memory space."); } // Get same bitwidth flat vector type of new element type. @@ -689,7 +688,6 @@ class FenceToXeVMPattern : public OpConversionPattern { case xegpu::FenceScope::GPU: memScope = xevm::MemScope::DEVICE; break; - llvm_unreachable("Unknown XeGPU fence scope."); } xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL}; switch (op.getMemoryKind()) { @@ -699,7 +697,6 @@ class FenceToXeVMPattern : public OpConversionPattern { case xegpu::MemorySpace::SLM: addrSpace = xevm::AddrSpace::SHARED; break; - llvm_unreachable("Unknown XeGPU fence scope."); } xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace); rewriter.eraseOp(op); From 10a6aff857c9a033f290a1ed1fcf73f466a24cb8 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 26 Aug 2025 01:21:25 +0000 Subject: [PATCH 12/18] Remove redundant braces. --- .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 54 +++++++------------ 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 172b09bacdc03..ae42489196303 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -210,13 +210,13 @@ class CreateNdDescToXeVMPattern // Get shape values from op fold results. baseShapeW = createOffset(mixedSizes, rank - 1); baseShapeH = createOffset(mixedSizes, rank - 2); - if (sourceMemrefTy) { + if (sourceMemrefTy) // Cast index to i64. baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); - } else if (baseAddr.getType() != i64Ty) { + else if (baseAddr.getType() != i64Ty) // Pointer type may be i32. Cast to i64 if needed. baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); - } + // Populate payload. 
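      // For illustration only: the scalar arithmetic addOffset emits, shown as
      // a standalone sketch (illustrative name, not part of the patch). With
      // an assumed example, base 0x1000 plus offset 3 over 4-byte f32 elements
      // yields 0x100C.
      #include <cstdint>
      static uint64_t flatAddress(uint64_t baseAddr, int64_t elemOffset,
                                  int64_t elemByteSize) {
        // The offset is an element count; scale to bytes before adding.
        return baseAddr + static_cast<uint64_t>(elemOffset * elemByteSize);
      }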
Value payLoadAsI64 = vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); @@ -288,9 +288,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { auto tdesc = adaptor.getTensorDesc(); auto tdescTy = op.getTensorDescType(); - if (tdescTy.getRank() != 2) { + if (tdescTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); - } VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); Value payLoadAsI64 = @@ -308,10 +307,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { Value offsetH; auto mixedOffsets = op.getMixedOffsets(); int64_t opOffsetsSize = mixedOffsets.size(); - if (opOffsetsSize != 0 && opOffsetsSize != 2) { + if (opOffsetsSize != 0 && opOffsetsSize != 2) return rewriter.notifyMatchFailure(op, "Expected 2D offsets or no offsets."); - } if (opOffsetsSize) { // If mixed offsets are provided by the op convert them to i32. offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); @@ -348,10 +346,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { int32_t vblocks = tdescTy.getArrayLength(); if constexpr (std::is_same_v) { VectorType srcVecTy = dyn_cast(adaptor.getValue().getType()); - if (!srcVecTy) { + if (!srcVecTy) return rewriter.notifyMatchFailure( op, "Expected store value to be a vector type."); - } auto storeCacheControl = translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); Value src = adaptor.getValue(); @@ -419,10 +416,9 @@ class CreateDescToXeVMPattern ConversionPatternRewriter &rewriter) const override { auto eTy = op.getTensorDescType().getElementType(); auto eBw = eTy.getIntOrFloatBitWidth(); - if (eBw % 8 != 0) { + if (eBw % 8 != 0) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); - } auto loc = op.getLoc(); // Offsets are provided as scalar i64 by type converter. auto offsets = adaptor.getOffsets(); @@ -447,10 +443,9 @@ class UpdateOffsetToXeVMPattern ConversionPatternRewriter &rewriter) const override { auto eTy = op.getTensorDescType().getElementType(); auto eBw = eTy.getIntOrFloatBitWidth(); - if (eBw % 8 != 0) { + if (eBw % 8 != 0) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); - } auto loc = op.getLoc(); // Scatter descriptor is provided as scalar i64 by type converter. // Offsets are provided as scalar i64 by type converter. @@ -475,30 +470,27 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { Value basePtrI64; // Load result or Store valye Type can be vector or scalar. Type valOrResTy; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) valOrResTy = op.getResult().getType(); - } else { + else valOrResTy = adaptor.getValue().getType(); - } VectorType valOrResVecTy = dyn_cast(valOrResTy); bool hasScalarVal = !valOrResVecTy; int64_t elemBitWidth = hasScalarVal ? valOrResTy.getIntOrFloatBitWidth() : valOrResVecTy.getElementType().getIntOrFloatBitWidth(); // Element type must be multiple of 8 bits. - if (elemBitWidth % 8 != 0) { + if (elemBitWidth % 8 != 0) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); - } int64_t elemByteSize = elemBitWidth / 8; // Default memory space is global. LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); // If tensor descriptor is available, we use its memory space. 
- if (tdescTy) { + if (tdescTy) ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); - } // Base pointer can come from source (load) or dest (store). // If they are memrefs, we use their memory space. if constexpr (std::is_same_v) { @@ -524,18 +516,17 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { Value offsets = adaptor.getOffsets(); Value mask = adaptor.getMask(); if (offsets) { - if (dyn_cast(offsets.getType())) { + if (dyn_cast(offsets.getType())) // Offset needs be scalar. Single element vector is converted to scalar // by type converter. return rewriter.notifyMatchFailure(op, "Expected offsets to be a scalar."); - } else { + else // If offsets are provided, we add them to the base pointer. // Offsets are in number of elements, we need to multiply by // element byte size. basePtrI64 = addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); - } } // Convert base pointer (i64) to LLVM pointer type. Value basePtrLLVM = @@ -543,13 +534,12 @@ class LoadStoreToXeVMPattern : public OpConversionPattern { Value maskForLane; VectorType maskVecTy = dyn_cast(mask.getType()); - if (maskVecTy) { + if (maskVecTy) // Mask needs be scalar. Single element vector is converted to scalar by // type converter. return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar."); - } else { + else maskForLane = mask; - } if constexpr (std::is_same_v) { scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy}, maskForLane, true, true); @@ -609,10 +599,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern { auto tdescTy = op.getTensorDescType(); Value basePtrI64 = adaptor.getSource(); // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. - if (basePtrI64.getType() != rewriter.getI64Type()) { + if (basePtrI64.getType() != rewriter.getI64Type()) basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrI64); - } Value offsets = adaptor.getOffsets(); if (offsets) { VectorType offsetsVecTy = dyn_cast(offsets.getType()); @@ -637,10 +626,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern { elemByteSize = *op.getOffsetAlignByte(); } if (elemBitWidth != 0) { - if (elemBitWidth % 8 != 0) { + if (elemBitWidth % 8 != 0) return rewriter.notifyMatchFailure( op, "Expected element type bit width to be multiple of 8."); - } elemByteSize = elemBitWidth / 8; } basePtrI64 = @@ -651,10 +639,9 @@ class PrefetchToXeVMPattern : public OpConversionPattern { LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); // If tensor descriptor is available, we use its memory space. - if (tdescTy) { + if (tdescTy) ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); - } // If source is a memref, we use its memory space. if (auto memRefTy = dyn_cast(op.getSource().getType())) { auto addrSpace = memRefTy.getMemorySpaceAsInt(); @@ -883,9 +870,8 @@ struct ConvertXeGPUToXeVMPass return VectorType::get(sum, elemType); }); typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { - if (type.isScattered()) { + if (type.isScattered()) return IntegerType::get(&getContext(), 64); - } auto i32Type = IntegerType::get(&getContext(), 32); return VectorType::get(8, i32Type); }); From dea2933d2db7100cc84556f6d9321f3a82bba331 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Tue, 26 Aug 2025 14:52:08 +0000 Subject: [PATCH 13/18] Add element bitwidth restriction. 
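
The added guard only lets 2D block loads/stores through when the element type
occupies a whole number of bytes; sub-byte element types are rejected with a
match failure instead of producing invalid XeVM block ops. A minimal
standalone sketch of the check (names are illustrative, not taken from the
patch):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the `elemBitSize % 8 != 0` rejection added to the nd path.
    static bool isByteAddressable(uint64_t elemBitWidth) {
      return elemBitWidth % 8 == 0;
    }

    int main() {
      for (uint64_t bw : {1, 4, 8, 16, 32})
        std::printf("%llu-bit element -> %s\n",
                    static_cast<unsigned long long>(bw),
                    isByteAddressable(bw) ? "lowered" : "rejected");
    }
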
--- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index ae42489196303..712ed1ee88988 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -290,6 +290,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { auto tdescTy = op.getTensorDescType(); if (tdescTy.getRank() != 2) return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + auto elemType = tdescTy.getElementType(); + auto elemBitSize = elemType.getIntOrFloatBitWidth(); + if (elemBitSize % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); Value payLoadAsI64 = @@ -333,8 +338,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern { Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); // Compute element byte size and surface width in bytes. - auto elemType = tdescTy.getElementType(); - auto elemBitSize = elemType.getIntOrFloatBitWidth(); Value elemByteSize = arith::ConstantIntOp::create( rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); Value surfaceW = From b01086ed5279acc7018897708808493e8a4b3d98 Mon Sep 17 00:00:00 2001 From: "Lee, Sang Ik" Date: Wed, 27 Aug 2025 17:19:37 +0000 Subject: [PATCH 14/18] Address comments. --- mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 712ed1ee88988..8d2ad4e999d38 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -151,10 +151,11 @@ class CreateNdDescToXeVMPattern auto loc = op.getLoc(); auto source = op.getSource(); // Op is lowered to a code sequence that populates payload. - // Payload is a 8xi32 vector. + // Payload is a 8xi32 vector. Offset to individual fields are defined in + // NdTdescOffset enum. Type payloadElemTy = rewriter.getI32Type(); - Type i64Ty = rewriter.getI64Type(); VectorType payloadTy = VectorType::get(8, payloadElemTy); + Type i64Ty = rewriter.getI64Type(); // 4xi64 view is used for inserting the base pointer. VectorType payloadI64Ty = VectorType::get(4, i64Ty); // Initialize payload to zero. @@ -180,12 +181,12 @@ class CreateNdDescToXeVMPattern // If source is a memref, we need to extract the aligned pointer as index. // Pointer type is passed as i32 or i64 by type converter. if (sourceMemrefTy) { - baseAddr = - memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); if (!sourceMemrefTy.hasStaticShape()) { op.emitError() << "Expected static memref shape."; return failure(); } + baseAddr = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); } else { baseAddr = adaptor.getSource(); } @@ -198,8 +199,8 @@ class CreateNdDescToXeVMPattern }; // Offsets can be either 2D or not provided (0 is used). 
     if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, rank - 1);
-      offsetH = createOffset(mixedOffsets, rank - 2);
+      offsetW = createOffset(mixedOffsets, 1);
+      offsetH = createOffset(mixedOffsets, 0);
     } else if (mixedOffsets.size() == 0) {
       offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
       offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     } else {
       return rewriter.notifyMatchFailure(op,
                                          "Expected 2D offsets or no offsets.");
     }
     // Get shape values from op fold results.
-    baseShapeW = createOffset(mixedSizes, rank - 1);
-    baseShapeH = createOffset(mixedSizes, rank - 2);
+    baseShapeW = createOffset(mixedSizes, 1);
+    baseShapeH = createOffset(mixedSizes, 0);
     if (sourceMemrefTy)
       // Cast index to i64.
       baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);

From a2887f24e22dc1293f4c8851ae1d78748a357d17 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik"
Date: Wed, 27 Aug 2025 18:25:59 +0000
Subject: [PATCH 15/18] Address comments.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 20 +++---
 .../XeGPUToXeVM/create_nd_tdesc.mlir          | 64 ++++++++++++-------
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |  3 +-
 3 files changed, 53 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8d2ad4e999d38..187d8d805a06e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1,4 +1,4 @@
-//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
+//===-- XeGPUToXeVM.cpp - XeVM to LLVM dialect conversion -------*- C++ -*-===//
 //
 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -182,8 +182,7 @@ class CreateNdDescToXeVMPattern
     // Pointer type is passed as i32 or i64 by type converter.
     if (sourceMemrefTy) {
       if (!sourceMemrefTy.hasStaticShape()) {
-        op.emitError() << "Expected static memref shape.";
-        return failure();
+        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
@@ -211,13 +210,13 @@ class CreateNdDescToXeVMPattern
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy)
+    if (sourceMemrefTy) {
       // Cast index to i64.
       baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    else if (baseAddr.getType() != i64Ty)
+    } else if (baseAddr.getType() != i64Ty) {
       // Pointer type may be i32. Cast to i64 if needed.
       baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-
+    }
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -520,17 +519,18 @@ class LoadStoreToXeVMPattern : public OpConversionPattern {
     Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
     if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType()))
+      if (dyn_cast<VectorType>(offsets.getType())) {
         // Offset needs be scalar. Single element vector is converted to scalar
         // by type converter.
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
-      else
+      } else {
         // If offsets are provided, we add them to the base pointer.
         // Offsets are in number of elements, we need to multiply by
        // element byte size.
         basePtrI64 =
             addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+      }
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -538,11 +538,11 @@ class LoadStoreToXeVMPattern : public OpConversionPattern {
     Value maskForLane;
     VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
-    if (maskVecTy)
+    if (maskVecTy) {
       // Mask needs be scalar. Single element vector is converted to scalar by
       // type converter.
       return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
-    else
+    } else
       maskForLane = mask;
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
       scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ba7ece8ccbebe..4ff95b40fe68c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,24 +2,24 @@
 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64
-  // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
+  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
   gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
-      %stride1: index, %stride2: index) kernel {
+      %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
-    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+    // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-    // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
-    // CHECK: %[[VAR3:.*]] = arith.index_cast %[[ARG3]] : index to i32
-    // CHECK: %[[VAR5:.*]] = arith.index_cast %[[ARG2]] : index to i32
+    // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
+    // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
+    // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
+    // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
     // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR9:.*]] = vector.insert %[[VAR3]], %[[VAR8]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR10:.*]] = vector.insert %[[VAR5]], %[[VAR9]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR11:.*]] = vector.insert %[[C0_I32]], %[[VAR10]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR12:.*]] = vector.insert %[[C0_I32_0]], %[[VAR11]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
+    // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
     %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1,
         %shape2], strides:[%stride1, %stride2] : ui64 -> !xegpu.tensor_desc<8x16xf32>
@@ -28,21 +28,39 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
-    // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32
-    // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32
+    // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
+    // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
     // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-    // CHECK: %[[C16_I32:.*]] = arith.trunci %c16_i64 : i64 to i32
+    // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
     // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-    // CHECK: %[[C8_I32:.*]] = arith.trunci %c8_i64 : i64 to i32
-    // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+    // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
     // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR17:.*]] = vector.insert %[[C16_I32]], %[[VAR16]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR18:.*]] = vector.insert %[[C8_I32]], %[[VAR17]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR19:.*]] = vector.insert %[[C0_I32_2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR20:.*]] = vector.insert %[[C0_I32_3]], %[[VAR19]] [5] : i32 into vector<8xi32>
+    // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
+    // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
+    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+    // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
+    // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
+    // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
+    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
+    // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
+    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+    // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
+    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
+    // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
+    // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
+    // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
+    // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
+    // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
+    // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
+    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index 15940fc4aca26..e6f22f0a9acbb 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -5,7 +5,8 @@ #sg_map_c_f32 = #xegpu.layout
 gpu.module @load_store_check {
-  //CHECK: func.func @dpas(%[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>) -> vector<8xf32>
+  // CHECK-LABEL: func.func @dpas(
+  // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
   func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
     // Loads are checked in a separate test.
     // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = , types = }

From a6833962ae48c49f2583165985d7db1e2761f852 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik"
Date: Wed, 27 Aug 2025 18:32:33 +0000
Subject: [PATCH 16/18] Add test case description.

---
 mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index 2445c4b341657..b28a8c2ccf843 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s
 
+// This file contains tests for materialization patterns added to handle custom type conversions
+// added on top of LLVM type converter.
+
 gpu.module @materializecast {
   // CHECK-LABEL: gpu.func @materialize_memref
   // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>

From 8f51ef433f5f0553a130df82d909fd63a0712f62 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik"
Date: Wed, 27 Aug 2025 21:31:22 +0000
Subject: [PATCH 17/18] Address comments.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 187d8d805a06e..906b943a98756 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -30,6 +30,8 @@
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <numeric>
+
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -39,6 +41,10 @@ using namespace mlir;
 
 namespace {
+// TODO: Below are uArch dependent values, should move away from hardcoding
+static constexpr int32_t systolicDepth{8};
+static constexpr int32_t executionSize{16};
+
 // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
 enum class NdTdescOffset : uint32_t {
   BasePtr = 0, // Base pointer (i64)
@@ -746,9 +752,6 @@ class DpasToXeVMPattern : public OpConversionPattern {
         VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
     if (cvecty != cNty)
       c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
-    // Below are uArch dependent values, should move away from hardcoding
-    constexpr int32_t systolicDepth{8};
-    constexpr int32_t executionSize{16};
     Value dpasRes = xevm::MMAOp::create(
         rewriter, loc, cNty, aVec, bVec, c,
         xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
@@ -867,10 +870,9 @@ struct ConvertXeGPUToXeVMPass
       if (rank < 1 || type.getNumElements() == 1)
        return elemType;
       // Otherwise, convert the vector to a flat vector type.
-      unsigned sum = 1;
-      for (unsigned i = 0; i < rank; i++) {
-        sum *= type.getShape()[i];
-      }
+      int64_t sum =
+          std::accumulate(type.getShape().begin(), type.getShape().end(),
+                          int64_t{1}, std::multiplies<int64_t>());
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {

From 7b6bdbea7ed0937c8652a717c1f638f4bd545065 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik"
Date: Wed, 27 Aug 2025 22:54:32 +0000
Subject: [PATCH 18/18] Update incorrect header.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 906b943a98756..d8dd09a6280c0 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1,4 +1,4 @@
-//===-- XeGPUToXeVM.cpp - XeVM to LLVM dialect conversion -------*- C++ -*-===//
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
 //
 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
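
For reference, a minimal standalone sketch of the vector-flattening rule used
by the type converter above (the helper name and plain-C++ setup are assumed
for illustration; the pass itself operates on mlir::VectorType). A shape of
{8, 16} flattens to 128 lanes, so vector<8x16xf32> becomes vector<128xf32>,
while rank-0 or single-element vectors convert to the element type:

  #include <cstdint>
  #include <functional>
  #include <numeric>
  #include <vector>

  // Sketch only: product of the dimension sizes, computed the same way the
  // conversion above does with std::accumulate and std::multiplies.
  static int64_t flattenedNumElements(const std::vector<int64_t> &shape) {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }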