diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index eb05628d4772b..e42799689e490 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -85,4 +85,16 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
                            "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
 }
 
+def XeGPUOptimizeBlockLoads : Pass<"xegpu-optimize-block-loads"> {
+  let summary = "Optimize XeGPU block load operations";
+  let description = [{
+    This pass rewrites XeGPU loadNd operations into more efficient forms.
+    This includes:
+    - Rewriting transposed B loads into forms that can use the HW block
+      transpose instructions.
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index a480195eebd00..1776a209d0bf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -61,7 +61,8 @@ struct UnrollOptions {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-
+/// Appends patterns for optimizing block load operations into `patterns`.
+void populateXeGPUOptimizeBlockLoadsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 /// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 620a2fe43d682..58092c3bb9ed2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -166,6 +166,15 @@ SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
 SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
+
+/// Helper Function to find a proper instruction multiple for the user-supplied
+/// sg-level data shape (given by `dim`). `candidates` are uArch allowed shapes.
+/// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
+/// array length).
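+/// For example (illustrative values, not taken from any uArch table):
+/// getLargestDivisor(32, {8, 16}, {1, 2}) considers the candidate products
+/// {8, 16, 16, 32} and returns 32, the largest product that evenly divides 32.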
+template <typename T>
+int getLargestDivisor(T dim, ArrayRef<T> candidates,
+                      ArrayRef<T> candidateMultiples = {});
+
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index e6f76067094ce..29b645feab2c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUWgToSgDistribute.cpp
   XeGPUPropagateLayout.cpp
   XeGPUVectorLinearize.cpp
+  XeGPUOptimizeBlockLoads.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
new file mode 100644
index 0000000000000..4dc5ea4f7bb24
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp
@@ -0,0 +1,490 @@
+//===- XeGPUOptimizeBlockLoads.cpp - XeGPU optimize block loads -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUOPTIMIZEBLOCKLOADS
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-optimize-block-loads"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+/// Get the 2D lane data from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneData(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneData = layout.getEffectiveLaneDataAsInt();
+  if (laneData.size() != 2)
+    return std::nullopt;
+  return laneData;
+}
+
+/// Get the 2D lane layout from a tensor desc type if it exists.
+static std::optional<SmallVector<int64_t>>
+getMaybeLaneLayout(xegpu::TensorDescType tdescType) {
+  auto layout = tdescType.getLayoutAttr();
+  if (!layout)
+    return std::nullopt;
+  auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+  if (laneLayout.size() != 2)
+    return std::nullopt;
+  return laneLayout;
+}
+
+/// A layout can be optimized if its lane layout is transposed (lane[0] != 1 &&
+/// lane[1] == 1) but its inner lane data is not equal to [1, 1].
+/// Example:
+///   !xegpu.tensor_desc<16x16xf16,
+///     #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+/// In this case, the lane layout is transposed (from the usual [1, SG_SIZE]
+/// form), indicating that this load requires a transpose effect. However, the
+/// lane data is [1, 2], meaning that each lane must grab 2 f16 elements from
+/// the inner dimension. We convert this to an optimized form by converting the
+/// tensor_desc to i32 type such that the lane data becomes [1, 1]. This makes
+/// it easy for the later lowering to use the load-with-transpose instruction.
+static bool canBeOptimizedForTranspose(ArrayRef<int64_t> laneLayout,
+                                       ArrayRef<int64_t> laneData) {
+  if (laneLayout.size() != 2 || laneData.size() != 2)
+    return false;
+  if (laneLayout[0] == 1 || laneLayout[1] != 1)
+    return false;
+  if (laneData[0] != 1 || laneData[1] == 1)
+    return false;
+  return true;
+}
+
+/// A tensor desc type can be optimized if its element type is narrower than
+/// 32 bits and its layout can be optimized.
+static bool canBeOptimizedForTranspose(xegpu::TensorDescType tdescType) {
+  // If the element type is 32 bits or wider, the layout is already valid.
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  if (elementTyBitwidth >= 32)
+    return false;
+  auto maybeLaneLayout = getMaybeLaneLayout(tdescType);
+  auto maybeLaneData = getMaybeLaneData(tdescType);
+  if (!maybeLaneData || !maybeLaneLayout)
+    return false;
+  return canBeOptimizedForTranspose(*maybeLaneLayout, *maybeLaneData);
+}
+
+/// Check if a tensor desc type can be optimized for transpose; if so, return a
+/// new optimized tensor desc type with a valid transpose layout. Otherwise,
+/// return the original type.
+static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
+                                         const uArch *targetuArch) {
+  if (!canBeOptimizedForTranspose(tdescType))
+    return tdescType;
+  auto laneData = getMaybeLaneData(tdescType)
+                      .value(); // Lane data must exist if we reach here.
+  int64_t innerLaneData = laneData[1];
+  int elementTyBitwidth = tdescType.getElementType().getIntOrFloatBitWidth();
+  // Required shape is the total shape of the vector result that this tensor
+  // desc must eventually load, after adjusting for the new bitwidth and array
+  // length.
+  SmallVector<int64_t> requiredShape(tdescType.getShape());
+  requiredShape.back() =
+      requiredShape.back() * tdescType.getArrayLength() / innerLaneData;
+  int newBitWidth = elementTyBitwidth * innerLaneData;
+  Type newElemTy = IntegerType::get(tdescType.getContext(), newBitWidth);
+  // Supported shape is the largest transpose shape that the hardware supports
+  // and that is less than or equal to the required shape.
+  auto *blockLoadTarget = dyn_cast(
+      targetuArch->getInstruction(InstructionKind::Subgroup2DBlockLoad));
+  auto maybeHWParams = blockLoadTarget->getBlockWidthHeightCount(
+      newElemTy, /** has transform */ false, /** has transpose */ true);
+  // If no HW params are found, return the original type.
+  if (!maybeHWParams)
+    return tdescType;
+  auto [widths, heights, counts] = maybeHWParams.value();
+  // TODO: Currently we expect array length to be 1 for the transpose case.
+  if (counts.size() != 1 || counts[0] != 1)
+    return tdescType;
+  int arrayLen = counts[0];
+  int supportedHeight = xegpu::getLargestDivisor(
+      static_cast<int>(requiredShape[0]), heights);
+  int supportedWidth = xegpu::getLargestDivisor(
+      static_cast<int>(requiredShape[1]), widths);
+  // If no supported height or width is found, return the original type.
+ if (supportedHeight == -1 || supportedWidth == -1) + return tdescType; + + SmallVector supportedShape = {supportedHeight, supportedWidth}; + xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get( + tdescType.getContext(), + tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1}); + // Array length can not be larger than 1 for transpose case. + return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen, + tdescType.getBoundaryCheck(), + tdescType.getMemorySpace(), newLayout); +} + +/// Helper to convert an OpFoldResult to Value. +static Value convertToValue(ConversionPatternRewriter &rewriter, Location loc, + OpFoldResult ofr) { + std::optional mayBeInt = getConstantIntValue(ofr); + if (mayBeInt) + return arith::ConstantIndexOp::create(rewriter, loc, *mayBeInt).getResult(); + return llvm::cast(ofr); +} + +/// Helper to divide a Value by a constant integer. +static Value divideByConstant(ConversionPatternRewriter &rewriter, Location loc, + Value val, int64_t constant) { + // If the constant is a power of 2, use right shift for division. + if (llvm::isPowerOf2_64(constant)) { + int64_t shiftAmount = llvm::Log2_64(constant); + return arith::ShRUIOp::create( + rewriter, loc, val, + arith::ConstantIndexOp::create(rewriter, loc, shiftAmount) + .getResult()) + .getResult(); + } + auto constantOp = + arith::ConstantIndexOp::create(rewriter, loc, constant).getResult(); + return arith::DivUIOp::create(rewriter, loc, val, constantOp).getResult(); +} + +/// This function takes a larger register block `data` and generates multiple +/// smaller loads (size given by `newTensorDesc`) to fill in the `data` block +/// starting from `offsets`. +static Value generateLoads(ConversionPatternRewriter &rewriter, + TypedValue data, + SmallVector offsets, + TypedValue newTensorDesc, + xegpu::LoadNdOp origLoadOp) { + Location loc = data.getLoc(); + assert(offsets.size() >= 2 && "Expecting at least 2 offsets for 2D LoadNdOp"); + Value offsetDim0 = convertToValue(rewriter, loc, offsets[offsets.size() - 2]); + Value offsetDim1 = convertToValue(rewriter, loc, offsets[offsets.size() - 1]); + SmallVector supportedShape(newTensorDesc.getType().getShape()); + // Compute the ratio between original shape and supported shape. We need to + // generate loads in this ratio arrangement. + auto shapeRatio = computeShapeRatio(data.getType().getShape(), + supportedShape) + .value(); // `ratio` must be defined if we reach here. + for (int64_t h = 0; h < shapeRatio[0]; ++h) { + for (int64_t w = 0; w < shapeRatio[1]; ++w) { + int64_t localOffsetDim0 = h * supportedShape[0]; + int64_t localOffsetDim1 = w * supportedShape[1]; + Value loadOffsetX = arith::AddIOp::create( + rewriter, loc, offsetDim0, + arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim0) + .getResult()); + Value loadOffsetY = arith::AddIOp::create( + rewriter, loc, offsetDim1, + arith::ConstantIndexOp::create(rewriter, loc, localOffsetDim1) + .getResult()); + auto loadOp = xegpu::LoadNdOp::create( + rewriter, loc, + VectorType::get(supportedShape, data.getType().getElementType()), + newTensorDesc, ArrayRef{loadOffsetX, loadOffsetY}, + origLoadOp.getPackedAttr(), origLoadOp.getTransposeAttr(), + origLoadOp.getL1HintAttr(), origLoadOp.getL2HintAttr(), + origLoadOp.getL3HintAttr()); + // Set the layout for the loadOp. + auto layoutAttr = newTensorDesc.getType().getLayoutAttr(); + xegpu::setDistributeLayoutAttr(loadOp->getOpResult(0), layoutAttr); + // Insert the loaded block into the right position in data. 
+ auto insertOp = vector::InsertStridedSliceOp::create( + rewriter, loc, loadOp.getResult(), data, + ArrayRef{localOffsetDim0, localOffsetDim1}, + ArrayRef{1, 1}); + // InsertOp must have the same layout as newTensorDesc. + xegpu::setDistributeLayoutAttr(insertOp->getOpResult(0), layoutAttr); + data = insertOp.getResult(); + } + } + return data; +} + +/// Checks if a CreateNdDescOp can be optimized for transpose, if so creates a +/// new CreateNdDescOp with optimized tensor desc type. This involves extracting +/// the base pointer from the original memory source and adjusting the shape and +/// strides of the tensor desc to fit with the new optimized transpose layout. +class XeGPUCreateNdDescOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateNdDescOp createNdOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tdescTy = createNdOp.getType(); + // Get the target uArch info. + auto chipStr = xegpu::getChipStr(createNdOp); + // Check if the chip is supported. + assert( + chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg") && + "Expecting target chip to be pvc or bmg for transpose optimization."); + const uArch *targetuArch = xegpu::uArch::getUArch(chipStr.value()); + + auto convertType = tryOptimize(tdescTy, targetuArch); + if (convertType == tdescTy) + return failure(); + auto strides = createNdOp.getMixedStrides(); + auto maybeConstInnerStride = getConstantIntValue(strides.back()); + // Only row-major memrefs are expected for now. + if (!maybeConstInnerStride || *maybeConstInnerStride != 1) + return rewriter.notifyMatchFailure( + createNdOp, "Expecting row-major memref for transpose optimization."); + Value source = createNdOp.getSource(); + auto optionalLaneData = getMaybeLaneData(tdescTy); + assert(optionalLaneData && "Expected 2D lane data"); + auto laneData = optionalLaneData.value(); + int64_t innerLaneData = laneData[1]; + auto memrefType = dyn_cast(source.getType()); + // Inner dimension of the shape must be adjusted based on innerLaneData. + SmallVector modifiedShape(createNdOp.getMixedSizes()); + modifiedShape.back() = divideByConstant( + rewriter, createNdOp.getLoc(), + convertToValue(rewriter, createNdOp.getLoc(), modifiedShape.back()), + innerLaneData); + // Similarly, second to last stride must be adjusted. + assert(strides.size() >= 2 && + "Expected at least 2 strides for CreateNdDescOp"); + SmallVector modifiedStrides(strides); + modifiedStrides[modifiedStrides.size() - 2] = divideByConstant( + rewriter, createNdOp.getLoc(), + convertToValue(rewriter, createNdOp.getLoc(), + modifiedStrides[modifiedStrides.size() - 2]), + innerLaneData); + + // If the source is a static memref, we need to extract the pointer to + // base address. + if (memrefType && memrefType.hasStaticShape()) { + auto extractOp = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, createNdOp.getLoc(), source); + source = arith::IndexCastOp::create(rewriter, createNdOp.getLoc(), + rewriter.getI64Type(), + extractOp.getResult()) + .getResult(); + } + // Create a new CreateNdDescOp with the modified shape and converted type. 
+ auto newCreateNdDescOp = xegpu::CreateNdDescOp::create( + rewriter, createNdOp.getLoc(), convertType, source, modifiedShape, + modifiedStrides); + rewriter.replaceOp(createNdOp, newCreateNdDescOp.getResult()); + return success(); + } +}; + +/// Checks if a LoadNdOp consumes a tensor desc type that was rewritten for +/// tranpose optimization. If so, rewrites the LoadNdOp to to align with the +/// adjusted tensor desc type. This can result in multiple LoadNdOps being +/// generated to fill in the original load shape. +class XeGPULoadNdDescOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadNdOp loadNdOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto origTensorDescType = loadNdOp.getTensorDescType(); + auto adaptorType = + cast(adaptor.getTensorDesc().getType()); + if (adaptorType == origTensorDescType) + return failure(); + // Offsets must be adjusted based on innerLaneData. + auto laneData = getMaybeLaneData(loadNdOp.getTensorDescType()).value(); + int64_t innerLaneData = laneData[1]; + auto offsets = loadNdOp.getMixedOffsets(); + if (offsets.empty()) + return rewriter.notifyMatchFailure(loadNdOp, + "Expecting offsets in LoadNd"); + SmallVector modifiedOffsets(offsets); + modifiedOffsets.back() = divideByConstant( + rewriter, loadNdOp.getLoc(), + convertToValue(rewriter, loadNdOp.getLoc(), modifiedOffsets.back()), + innerLaneData); + // Get the 2D data shape of this loadNdOp in its original type including + // array length. + SmallVector origDataShape(origTensorDescType.getShape()); + // Adjust the data shape based on innerLaneData. + origDataShape.back() /= innerLaneData; + // HW supported shape is the new tensor desc shape after conversion. + SmallVector hwSupportedShape(adaptorType.getShape()); + VectorType origVectorType = + VectorType::get(origDataShape, adaptorType.getElementType()); + Value data; + // Orig data shape is 3D for the array length case. + if (origTensorDescType.getArrayLength() > 1) { + SmallVector arraySlices; + for (int64_t i = 0; i < origTensorDescType.getArrayLength(); ++i) { + Value slice = arith::ConstantOp::create( + rewriter, loadNdOp->getLoc(), origVectorType, + rewriter.getZeroAttr(origVectorType)); + // Increase the Y offset for each array slice. + Value offsetY = convertToValue(rewriter, loadNdOp->getLoc(), + modifiedOffsets.back()); + modifiedOffsets.back() = + arith::AddIOp::create( + rewriter, loadNdOp->getLoc(), offsetY, + arith::ConstantIndexOp::create(rewriter, loadNdOp->getLoc(), + i * origDataShape[1]) + .getResult()) + .getResult(); + slice = generateLoads( + rewriter, cast>(slice), modifiedOffsets, + cast>(adaptor.getTensorDesc()), + loadNdOp); + // BitCast back to original load shape without array length. + auto bitcastType = VectorType::get(origTensorDescType.getShape(), + origTensorDescType.getElementType()); + auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(), + bitcastType, slice); + // BitCastOp must have the same layout as the original loadNdOp. 
+ xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0), + origTensorDescType.getLayoutAttr()); + arraySlices.push_back(bitCastOp.getResult()); + } + rewriter.replaceOpWithMultiple(loadNdOp, {arraySlices}); + return success(); + } + data = arith::ConstantOp::create( + rewriter, loadNdOp->getLoc(), + VectorType::get(origDataShape, adaptorType.getElementType()), + rewriter.getZeroAttr(origVectorType)); + data = generateLoads( + rewriter, cast>(data), modifiedOffsets, + cast>(adaptor.getTensorDesc()), + loadNdOp); + auto bitCastOp = vector::BitCastOp::create(rewriter, loadNdOp->getLoc(), + loadNdOp.getType(), data); + // BitCastOp must have the same layout as the original loadNdOp. + xegpu::setDistributeLayoutAttr(bitCastOp->getOpResult(0), + origTensorDescType.getLayoutAttr()); + rewriter.replaceOp(loadNdOp, bitCastOp); + return success(); + } +}; + +/// Vector ExtractOp must be processed if the original tensor desc type has +/// array length greater than 1. In this case, the LoadNdOp is replaced with +/// multiple LoadNdOps for each array slice making the extraction unnecessary. +/// In this case, we simply remove the ExtractOp. +class VectorExtractOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(vector::ExtractOp extractOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Check if the source of the extraction is split to multiple values. + if (adaptor.getSource().size() == 1) + return failure(); + auto mixedPos = extractOp.getMixedPosition(); + if (mixedPos.size() != 1) + return failure(); + auto mayBeInt = getConstantIntValue(mixedPos[0]); + if (!mayBeInt) + return failure(); + rewriter.replaceOp(extractOp, adaptor.getSource()[*mayBeInt]); + return success(); + } +}; + +} // namespace + +void xegpu::populateXeGPUOptimizeBlockLoadsPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +namespace { + +struct XeGPUOptimizeBlockLoadsPass final + : public xegpu::impl::XeGPUOptimizeBlockLoadsBase< + XeGPUOptimizeBlockLoadsPass> { + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter converter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + // This pass is only meant for PVC and BMG targets. If unsupported target + // is found, exit early. + bool isTargetSupported = false; + getOperation()->walk([&](gpu::GPUFuncOp funcOp) { + auto chipStr = xegpu::getChipStr(funcOp); + if (chipStr && (chipStr.value() == "pvc" || chipStr.value() == "bmg")) + isTargetSupported = true; + }); + + if (!isTargetSupported) { + DBGS() << "XeGPUOptimizeBlockLoadsPass only supports PVC and BMG targets." + << "\n"; + return; + } + + // CreateNdDescOp and LoadNdOp with optimizable tensor desc types must be + // converted. + target.addDynamicallyLegalOp( + [&](xegpu::CreateNdDescOp createNdOp) { + return !canBeOptimizedForTranspose(createNdOp.getType()); + }); + target.addDynamicallyLegalOp( + [&](xegpu::LoadNdOp loadNdOp) { + return !canBeOptimizedForTranspose(loadNdOp.getTensorDescType()); + }); + // Vector ExtractOps can have optimizable layouts if they extract from + // LoadNdOps with array length greater than 1. These ExtractOps must be + // converted. 
+ target.addDynamicallyLegalOp( + [&](vector::ExtractOp extractOp) { + auto layout = xegpu::getDistributeLayoutAttr(extractOp.getResult()); + if (!layout) + return true; + auto laneLayout = layout.getEffectiveLaneLayoutAsInt(); + auto laneData = layout.getEffectiveLaneDataAsInt(); + return !canBeOptimizedForTranspose(laneLayout, laneData); + }); + converter.addConversion([](Type type) { return type; }); + + target.addLegalDialect(); + scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns, + target); + xegpu::populateXeGPUOptimizeBlockLoadsPatterns(patterns); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + DBGS() << "Optimize block loads pass failed.\n"; + return signalPassFailure(); + } + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index 14c49e7f45706..4e1a539771d2f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -204,28 +204,6 @@ struct LayoutInfoLattice : public Lattice { using Lattice::Lattice; }; -/// Helper Function to find a proper instruction multiple for the user-supplied -/// sg-level data shape. `candidates` are uArch allowed shapes. -/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count). -template -int getLargestDivisor(T dim, ArrayRef candidates, - ArrayRef candidateMultiples = {}) { - static_assert(std::is_integral::value, "T must be an integer type"); - int largest = -1; - SmallVector multiples = {1}; - if (!candidateMultiples.empty()) - multiples = - SmallVector(candidateMultiples.begin(), candidateMultiples.end()); - for (T candidate : candidates) { - for (T multiple : multiples) { - int value = static_cast(candidate * multiple); - if (value != 0 && dim % value == 0 && value > largest) - largest = value; - } - } - return largest; -} - /// Helper Functions to get default layouts. A `default layout` is a layout that /// is assigned to a value when the layout is not fixed by some anchor operation /// (like DPAS). 
@@ -505,7 +483,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp( prefetch.emitWarning("No known block params found for the element type."); auto [bWidth, bHeight, bCount] = blockWHC.value(); SmallVector instData; - int instWidth = getLargestDivisor( + int instWidth = xegpu::getLargestDivisor( static_cast(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth, bCount); if (instWidth == -1) @@ -514,7 +492,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp( if (tdescTy.getRank() == 1) instData = {instWidth}; else { - int instHeight = getLargestDivisor( + int instHeight = xegpu::getLargestDivisor( static_cast(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight); if (instHeight == -1) prefetch.emitWarning( @@ -634,7 +612,7 @@ void LayoutInfoPropagation::visitDpasOp( const unsigned dataALen = aTy.getShape().front(); auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType()); const int maxALen = - getLargestDivisor(dataALen, ArrayRef(supportedALen)); + xegpu::getLargestDivisor(dataALen, ArrayRef(supportedALen)); if (maxALen == -1) dpas.emitWarning( "No suitable instruction multiple found for the given shape."); @@ -642,7 +620,7 @@ void LayoutInfoPropagation::visitDpasOp( const unsigned dataBLen = bTy.getShape().back(); auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType()); const int maxBLen = - getLargestDivisor(dataBLen, ArrayRef(supportedBLen)); + xegpu::getLargestDivisor(dataBLen, ArrayRef(supportedBLen)); if (maxBLen == -1) dpas.emitWarning( "No suitable instruction multiple found for the given shape."); @@ -662,7 +640,7 @@ void LayoutInfoPropagation::visitDpasOp( const unsigned dataCLen = bTy.getShape().back(); auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType()); const int maxCLen = - getLargestDivisor(dataCLen, ArrayRef(supportedCLen)); + xegpu::getLargestDivisor(dataCLen, ArrayRef(supportedCLen)); if (maxCLen == -1) dpas.emitWarning( "No suitable instruction multiple found for the given shape."); @@ -691,7 +669,7 @@ void LayoutInfoPropagation::visitStoreNdOp( store.emitWarning("No known block params found for the element type."); auto [bWidth, bHeight, bCount] = blockWHC.value(); SmallVector instData; - int instWidth = getLargestDivisor( + int instWidth = xegpu::getLargestDivisor( static_cast(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth, bCount); if (instWidth == -1) @@ -700,7 +678,7 @@ void LayoutInfoPropagation::visitStoreNdOp( if (dataTy.getRank() == 1) instData = {instWidth}; else { - int instHeight = getLargestDivisor( + int instHeight = xegpu::getLargestDivisor( static_cast(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight); if (instHeight == -1) store.emitWarning( diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index d575a415a3035..de9e09d427665 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -555,3 +555,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc, results.append(addElementwise(builder, loc, a, b)); return results; } + +template +int xegpu::getLargestDivisor(T dim, ArrayRef candidates, + ArrayRef candidateMultiples) { + static_assert(std::is_integral::value, "T must be an integer type"); + int largest = -1; + SmallVector multiples = {1}; + if (!candidateMultiples.empty()) + multiples = + SmallVector(candidateMultiples.begin(), candidateMultiples.end()); + for (T candidate : candidates) { + for (T multiple : multiples) { + int value = static_cast(candidate * multiple); + if (value != 0 
&& dim % value == 0 && value > largest) + largest = value; + } + } + return largest; +} + +/// Explicit instantiations +template int xegpu::getLargestDivisor(int dim, ArrayRef candidates, + ArrayRef candidateMultiples); +template int +xegpu::getLargestDivisor(unsigned dim, ArrayRef candidates, + ArrayRef candidateMultiples); diff --git a/mlir/test/Dialect/XeGPU/optimize-transpose.mlir b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir new file mode 100644 index 0000000000000..24a0de6ed48a5 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/optimize-transpose.mlir @@ -0,0 +1,280 @@ +// RUN: mlir-opt --xevm-attach-target='module=xevm_* chip=pvc' \ +// RUN: --xegpu-optimize-block-loads --canonicalize --cse --split-input-file %s | FileCheck %s + +// CHECK-LABEL: gpu.func @no_scf( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xf16>, %{{.*}}: vector<8x16xf16>) -> vector<8x16xf32> { +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C32:.*]] = arith.constant 32 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index +// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> +#a = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout +gpu.module @xevm_module { +gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16> -> !xegpu.tensor_desc<16x16xf16, #b> + %1 = xegpu.load_nd %0[%c0, %c32] { result_layout = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16> + %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #a } : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + gpu.return %6 : vector<8x16xf32> +} +} + +// ----- +// CHECK-LABEL: gpu.func @no_scf_i8( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<64x64xi8>, %{{.*}}: vector<8x32xi8>) -> vector<8x16xi32> { +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index +// CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x32xi8> +#a = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout +#c = #xegpu.layout +gpu.module @xevm_module { +gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8x16xi32> { + %c0 = arith.constant 0 : index + %c64 = arith.constant 64 : index + %0 = 
xegpu.create_nd_tdesc %arg0 : memref<64x64xi8> -> !xegpu.tensor_desc<16x32xi8, #b> + %1 = xegpu.load_nd %0[%c0, %c64] { result_layout = #b } : !xegpu.tensor_desc<16x32xi8, #b> -> vector<16x32xi8> + %2 = vector.transpose %1, [1, 0] { layout_result_0 = #bt } : vector<16x32xi8> to vector<32x16xi8> + %6 = xegpu.dpas %arg1, %2 { layout_result_0 = #c } : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> + gpu.return %6 : vector<8x16xi32> +} +} + + +// ----- +// CHECK-LABEL: gpu.func @gemm_b_transpose( +// CHECK-SAME: %{{.*}} memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%c128, 1] +// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) { +// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : +// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16> +#a = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout +gpu.module @xevm_module { +gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b> + %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) { + %5 = xegpu.load_nd %2[%c0, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16> + %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16> + %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %8 : vector<8x16xf32> + } {layout_result_0 = #a} + xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @nested_scf( +// CHECK-SAME: %{{.*}}: memref<256x256xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C16:.*]] = arith.constant 16 : index +// CHECK: %[[C256:.*]] = arith.constant 256 : index +// CHECK: scf.for %{{.*}} to %{{.*}} step %{{.*}} { +// 
CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) { +// CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] {layout_result_0 = #xegpu.layout< +// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1]>} : +// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16> +#a = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout +gpu.module @xevm_module { +gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c256 = arith.constant 256 : index + scf.for %arg8 = %c0 to %c256 step %c16 { + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%arg8, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %2 = xegpu.create_nd_tdesc %arg0 : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #a> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #b> + %4 = scf.for %arg3 = %c0 to %c256 step %c16 iter_args(%arg4 = %1) -> (vector<8x16xf32>) { + %5 = xegpu.load_nd %2[%arg8, %arg3] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf16, #a> -> vector<8x16xf16> + %6 = xegpu.load_nd %3[%arg8, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<16x16xf16, #b> -> vector<16x16xf16> + %7 = vector.transpose %6, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %8 = xegpu.dpas %5, %7, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %8 : vector<8x16xf32> + } {layout_result_0 = #a} + xegpu.store_nd %4, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + } + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @large_loads( +// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %{{.*}}: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<32x16xi32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 +// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout> +// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]] +// 
CHECK-SAME: {layout_result_0 = #xegpu.layout, offsets = [0, 0], strides = [1, 1]} +// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32> +// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout, offsets = [0, 8], strides = [1, 1]} +// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32> +// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : vector<32x16xi32> to vector<32x32xf16> +#a = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout +gpu.module @xevm_module { +gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> -> !xegpu.tensor_desc<32x32xf16, #b> + %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1) + -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { + %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16> + %7 = vector.extract_strided_slice %6 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %8 = vector.extract_strided_slice %6 {offsets = [0, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %9 = vector.extract_strided_slice %6 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %10 = vector.extract_strided_slice %6 {offsets = [16, 16], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x32xf16> to vector<16x16xf16> + %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a} + xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, 
#a> + xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + gpu.return +} +} + +// ----- +// CHECK-LABEL: gpu.func @array_length( +// CHECK-SAME: %{{.*}}: vector<8x16xf16>, %[[ARG1:[a-zA-Z0-9]+]]: memref<256x256xf16>, %arg2: memref<256x256xf32>) { +// CHECK: %[[C128:.*]] = arith.constant 128 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64 +// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 -> +// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> +// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index +// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> +// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout} +// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> +#a = #xegpu.layout +#b = #xegpu.layout +#bt = #xegpu.layout +gpu.module @xevm_module { +gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %0 = xegpu.create_nd_tdesc %arg2 : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #a> + %1 = xegpu.load_nd %0[%c0, %c0] { layout_result_0 = #a } : !xegpu.tensor_desc<8x16xf32, #a> -> vector<8x16xf32> + %3 = xegpu.create_nd_tdesc %arg1 : memref<256x256xf16> + -> !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr> + %4:4 = scf.for %arg3 = %c0 to %c256 step %c32 iter_args(%arg4 = %1, %arg5 = %1, %arg6 = %1, %arg7 = %1) + -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { + %6 = xegpu.load_nd %3[%c0, %arg3] { layout_result_0 = #b } + : !xegpu.tensor_desc<32x16xf16, #b, #xegpu.block_tdesc_attr> -> vector<2x32x16xf16> + %19 = vector.extract %6[0] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16> + %20 = vector.extract %6[1] { layout_result_0 = #b } : vector<32x16xf16> from vector<2x32x16xf16> + %7 = vector.extract_strided_slice %19 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %8 = vector.extract_strided_slice %19 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %9 = vector.extract_strided_slice %20 {offsets = [0, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %10 = 
vector.extract_strided_slice %20 {offsets = [16, 0], sizes = [16, 16], strides = [1, 1], layout_result_0 = #b } + : vector<32x16xf16> to vector<16x16xf16> + %11 = vector.transpose %7, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %12 = vector.transpose %8, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %13 = vector.transpose %9, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %14 = vector.transpose %10, [1, 0] { layout_result_0 = #bt } : vector<16x16xf16> to vector<16x16xf16> + %15 = xegpu.dpas %arg0, %11, %arg4 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %16 = xegpu.dpas %arg0, %12, %arg5 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %17 = xegpu.dpas %arg0, %13, %arg6 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %18 = xegpu.dpas %arg0, %14, %arg7 {layout_result_0 = #a} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + scf.yield %15, %16, %17, %18 : vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32> + } {layout_result_0 = #a, layout_result_1 = #a, layout_result_2 = #a, layout_result_3 = #a} + xegpu.store_nd %4#0, %0[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#1, %0[%c0, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#2, %0[%c16, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + xegpu.store_nd %4#3, %0[%c16, %c16] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #a> + gpu.return +} +}
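+
+// For reference, a minimal sketch (not a test) of the rewrite exercised above,
+// assuming a pvc target and using placeholder SSA names. A transposed f16 B
+// load such as
+//   %d = xegpu.create_nd_tdesc %src : memref<64x64xf16>
+//        -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+//   %v = xegpu.load_nd %d[%row, %col] : ... -> vector<16x16xf16>
+// is rewritten to load i32 elements so that lane_data becomes [1, 1], with the
+// inner size, stride, and column offset divided by 2:
+//   %ptr = memref.extract_aligned_pointer_as_index %src : memref<64x64xf16> -> index
+//   %i   = arith.index_cast %ptr : index to i64
+//   %d2  = xegpu.create_nd_tdesc %i, shape : [64, 32], strides : [32, 1] : i64
+//          -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+//   %v2  = xegpu.load_nd %d2[%row, %col / 2] : ... -> vector<16x8xi32>
+//   %v   = vector.bitcast %v2 : vector<16x8xi32> to vector<16x16xf16>
+// See the @no_scf test above for the exact checked output.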