diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td index 1ecd6ce95322b..3e81f2d0ed786 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td @@ -23,4 +23,19 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> { ]; } +def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> { + let summary = "Distribute XeGPU ops to work items"; + let description = [{ + The pass distributes subgroup level (SIMD) XeGPU ops to work items. + }]; + let dependentDialects = [ + "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect" + ]; + let options = [ + Option<"printOnly", "print-analysis-only", "bool", + /*default=*/"false", + "Print the result of the subgroup map propagation analysis and exit."> + ]; +} + #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt index 7fb64d3b97b87..124e904edb543 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRXeGPUTransforms XeGPUFoldAliasOps.cpp + XeGPUSubgroupDistribute.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp new file mode 100644 index 0000000000000..86e07697f437c --- /dev/null +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -0,0 +1,662 @@ +//===- XeGPUSubgroupDistribute.cpp - XeGPU Subgroup Distribute Pass -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace xegpu { +#define GEN_PASS_DEF_XEGPUSUBGROUPDISTRIBUTE +#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc" +} // namespace xegpu +} // namespace mlir + +using namespace mlir; +using namespace mlir::dataflow; + +/// HW dependent constants. +/// TODO: These constants should be queried from the target information. +constexpr unsigned subgroupSize = 16; // How many work items in a subgroup. +/// If DPAS A or B operands have low precision element types they must be packed +/// according to the following sizes. +constexpr unsigned packedSizeInBitsForDefault = + 16; // Minimum packing size per register for DPAS A. +constexpr unsigned packedSizeInBitsForDpasB = + 32; // Minimum packing size per register for DPAS B. 
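As an aside, here is a minimal standalone sketch (not part of the patch) of how these constants translate into per-lane packing factors for low-precision DPAS operands. The helper name and the driver are illustrative only; the resulting factors are the ones the tests later check as wi_data (f16 B operand: 2, i8 A operand: 2, i8 B operand: 4).

#include <cstdio>

// Mirrors the constants above.
constexpr unsigned kSubgroupSize = 16;      // Work items per subgroup.
constexpr unsigned kPackedBitsDefault = 16; // DPAS A / default packing.
constexpr unsigned kPackedBitsDpasB = 32;   // DPAS B (VNNI) packing.

// Packing applies only to low-precision element types; wider types keep one
// element per work item.
static unsigned packingFactor(unsigned elemBitWidth, unsigned packedBits) {
  return elemBitWidth >= packedBits ? 1 : packedBits / elemBitWidth;
}

int main() {
  // f16: A needs no packing (16/16 = 1); B packs two rows per lane (32/16 = 2).
  std::printf("f16: A=%u B=%u\n", packingFactor(16, kPackedBitsDefault),
              packingFactor(16, kPackedBitsDpasB));
  // i8: A packs two elements per lane (16/8 = 2); B packs four rows (32/8 = 4).
  std::printf("i8 : A=%u B=%u\n", packingFactor(8, kPackedBitsDefault),
              packingFactor(8, kPackedBitsDpasB));
  (void)kSubgroupSize; // 16 lanes determine wi_layout, not the packing factor.
  return 0;
}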
+
+namespace {
+
+///===----------------------------------------------------------------------===///
+/// Layout
+///===----------------------------------------------------------------------===///
+
+/// Helper class to store the ND layout of work items within a subgroup and
+/// the data owned by each work item.
+struct Layout {
+  SmallVector<int64_t> layout;
+  Layout() = default;
+  Layout(const Layout &other) = default;
+  Layout(std::initializer_list<int64_t> list) : layout(list) {}
+  void print(llvm::raw_ostream &os) const;
+  size_t size() const { return layout.size(); }
+  int64_t operator[](size_t idx) const;
+};
+
+void Layout::print(llvm::raw_ostream &os) const {
+  os << "[";
+  llvm::interleaveComma(layout, os);
+  os << "]";
+}
+
+int64_t Layout::operator[](size_t idx) const {
+  assert(idx < layout.size() && "Index out of bounds.");
+  return layout[idx];
+}
+
+/// WiLayout represents the layout of work items within a subgroup when it
+/// accesses some value. WiData represents the layout of the data owned by
+/// each work item.
+using WiLayout = Layout;
+using WiData = Layout;
+
+///===----------------------------------------------------------------------===///
+/// SGMap
+///===----------------------------------------------------------------------===///
+
+/// Helper class for tracking the analysis state of a value. For
+/// SGMapPropagation, the analysis state is simply the wi_layout and wi_data of
+/// each value. The purpose of this analysis is to propagate a unique layout
+/// for each value in the program, starting from operations with known layouts
+/// (like DPAS, StoreNd, etc.).
+///
+/// Given this, SGMap satisfies the following properties:
+///  1) SGMap is a lattice with two states: assigned and not assigned.
+///  2) Two SGMap values are equal if they are both assigned or both not
+///  assigned. The concrete value of the assigned state does not matter.
+///  3) The meet operator works as follows:
+///     - If the current state is assigned, return the current state (a unique
+///     layout is already assigned; do not change it).
+///     - Otherwise, return the other state.
+
+struct SGMap {
+private:
+  WiLayout wiLayout;
+  WiData wiData;
+
+public:
+  SGMap() = default;
+  SGMap(const SGMap &other) = default;
+  SGMap(const WiLayout &layout, const WiData &data)
+      : wiLayout(layout), wiData(data) {}
+
+  /// Two lattice values are equal if they have `some` layout. The actual
+  /// content of the layout does not matter.
+  bool operator==(const SGMap &other) const {
+    return this->isAssigned() == other.isAssigned();
+  }
+
+  static SGMap meet(const SGMap &lhs, const SGMap &rhs);
+
+  static SGMap join(const SGMap &lhs, const SGMap &rhs);
+
+  void print(raw_ostream &os) const;
+
+  bool isAssigned() const { return wiLayout.size() > 0 && wiData.size() > 0; }
+
+  SGMap getTransposedLayout(ArrayRef<int64_t> permutation) const;
+
+  const WiLayout &getLayout() const { return wiLayout; }
+  const WiData &getData() const { return wiData; }
+};
+
+void SGMap::print(raw_ostream &os) const {
+  if (isAssigned()) {
+    os << "wi_layout: ";
+    wiLayout.print(os);
+    os << ", wi_data: ";
+    wiData.print(os);
+  } else
+    os << "Not assigned.";
+}
+
+SGMap SGMap::meet(const SGMap &lhs, const SGMap &rhs) {
+  if (!lhs.isAssigned())
+    return rhs;
+  return lhs;
+}
+
+/// Since this is a backward analysis, the join method is not used.
+SGMap SGMap::join(const SGMap &lhs, const SGMap &rhs) {
+  llvm_unreachable("Join should not be triggered by SGMapPropagation.");
+}
+
+/// Get the transposed layout according to the given permutation.
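+/// For example, an f16 value consumed by the B operand of a DPAS has
+/// wi_layout [1, 16] and wi_data [2, 1]; if its producer carries a transpose
+/// with permutation [1, 0], the propagated layout becomes wi_layout [16, 1]
+/// and wi_data [1, 2], as exercised by @test_load_with_transpose_effect in
+/// the tests below.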
+SGMap SGMap::getTransposedLayout(ArrayRef permutation) const { + if (!isAssigned()) + return {}; + WiLayout newLayout; + WiData newData; + for (auto idx : permutation) { + newLayout.layout.push_back(wiLayout.layout[idx]); + newData.layout.push_back(wiData.layout[idx]); + } + return SGMap(newLayout, newData); +} + +///===----------------------------------------------------------------------===/// +/// SGMapLattice +///===----------------------------------------------------------------------===/// + +/// Lattice holding the SGMap for each value. +struct SGMapLattice : public Lattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SGMapLattice) + using Lattice::Lattice; +}; + +/// Helper Functions to get default layouts. A `default layout` is a layout that +/// is assigned to a value when the layout is not fixed by some anchor operation +/// (like DPAS). This is the natural layout work items are arranged in a +/// subgroup. + +/// Helper Function to get the default layout for uniform values like constants. +/// For 1D vector, wi_layout is [subgroupSize] and wi_data is [1]. +/// For 2D vector, wi_layout is [1, subgroupSize] and wi_data is [1, 1]. +static SGMap getDefaultSgMap(unsigned rank) { + assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector."); + if (rank == 1) + return SGMap(WiLayout({subgroupSize}), WiData({1})); + return SGMap(WiLayout({1, subgroupSize}), WiData({1, 1})); +} + +/// Helper to get the default layout for a vector type. +static SGMap getDefaultSgMap(VectorType vectorTy) { + /// Expecting a 1D or 2D vector. + assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) && + "Expected 1D or 2D vector."); + /// Expecting int or float element type. + assert(vectorTy.getElementType().isIntOrFloat() && + "Expected int or float element type."); + /// If the rank is 1, then return default layout for 1D vector. + if (vectorTy.getRank() == 1) + return getDefaultSgMap(1); + /// Packing factor is determined by the element type bitwidth. + int packingFactor = 1; + auto bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth(); + if (bitwidth < packedSizeInBitsForDefault) + packingFactor = packedSizeInBitsForDefault / bitwidth; + return SGMap(WiLayout({1, subgroupSize}), WiData({1, packingFactor})); +} + +/// Helper Function to get the expected layouts for DPAS operands. `wi_data` is +/// set according to the following criteria: +/// * For A operand, the data must be packed in minimum +/// `packedSizeInBitsForDefault` +/// * For B operand, the data must be packed in minimum +/// `packedSizeInBitsForDpasB` +static SGMap getSGMapForDPASOperand(VectorType vectorTy, unsigned operandNum) { + auto elementTy = vectorTy.getElementType(); + assert(elementTy.isIntOrFloat() && + "Expected int or float type in DPAS operands"); + WiLayout layout({1, subgroupSize}); + /// For B operand, data must be packed in minimum `packedDpasBSizeInBits` and + /// must have the VNNI format. + if (operandNum == 1 && + elementTy.getIntOrFloatBitWidth() < packedSizeInBitsForDpasB) { + WiData data( + {packedSizeInBitsForDpasB / elementTy.getIntOrFloatBitWidth(), 1}); + return SGMap(layout, data); + } + /// Otherwise, return the default layout for the vector type. + return getDefaultSgMap(vectorTy); +} + +///===----------------------------------------------------------------------===/// +/// SGMapPropagation +///===----------------------------------------------------------------------===/// + +/// Backward data flow analysis to propagate the wi_layout and wi_data of each +/// value in the program. 
Currently, the layouts for operands DPAS, StoreNd, and +/// StoreScatter are fixed (known before propagation). Purpose of this analysis +/// is to propagate those known layouts to all their producers and (other) +/// consumers. +class SGMapPropagation : public SparseBackwardDataFlowAnalysis { +private: + void visitDpasOp(xegpu::DpasOp dpas, ArrayRef operands, + ArrayRef results); + + void visitStoreNdOp(xegpu::StoreNdOp store, ArrayRef operands, + ArrayRef results); + + void visitStoreScatterOp(xegpu::StoreScatterOp storeScatter, + ArrayRef operands, + ArrayRef results); + + void visitLoadNdOp(xegpu::LoadNdOp load, ArrayRef operands, + ArrayRef results); + + void visitLoadGatherOp(xegpu::LoadGatherOp load, + ArrayRef operands, + ArrayRef results); + + void visitTransposeOp(vector::TransposeOp transpose, + ArrayRef operands, + ArrayRef results); + + void visitVectorBitcastOp(vector::BitCastOp bitcast, + ArrayRef operands, + ArrayRef results); + + void visitCreateDescOp(xegpu::CreateDescOp createDesc, + ArrayRef operands, + ArrayRef results); + + void visitUpdateNdOffsetOp(xegpu::UpdateNdOffsetOp updateNdOffset, + ArrayRef operands, + ArrayRef results); + + void visitVectorMultiReductionOp(vector::MultiDimReductionOp reduction, + ArrayRef operands, + ArrayRef results); + +public: + SGMapPropagation(DataFlowSolver &solver, SymbolTableCollection &symbolTable) + : SparseBackwardDataFlowAnalysis(solver, symbolTable) {} + using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis; + + LogicalResult visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; + + void visitBranchOperand(OpOperand &operand) override {}; + + void visitCallOperand(OpOperand &operand) override {}; + + void visitExternalCall(CallOpInterface call, + ArrayRef operands, + ArrayRef results) override {}; + + void setToExitState(SGMapLattice *lattice) override { + (void)lattice->meet(SGMap()); + } +}; +} // namespace + +LogicalResult +SGMapPropagation::visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) { + TypeSwitch(op) + .Case( + [&](auto dpasOp) { visitDpasOp(dpasOp, operands, results); }) + .Case( + [&](auto storeNdOp) { visitStoreNdOp(storeNdOp, operands, results); }) + .Case([&](auto storeScatterOp) { + visitStoreScatterOp(storeScatterOp, operands, results); + }) + .Case( + [&](auto loadNdOp) { visitLoadNdOp(loadNdOp, operands, results); }) + .Case([&](auto loadGatherOp) { + visitLoadGatherOp(loadGatherOp, operands, results); + }) + .Case([&](auto createDescOp) { + visitCreateDescOp(createDescOp, operands, results); + }) + .Case([&](auto updateNdOffsetOp) { + visitUpdateNdOffsetOp(updateNdOffsetOp, operands, results); + }) + /// No need to propagate the layout to operands in CreateNdDescOp because + /// they are scalars (offsets, sizes, etc.). + .Case([&](auto createNdDescOp) {}) + .Case([&](auto transposeOp) { + visitTransposeOp(transposeOp, operands, results); + }) + .Case([&](auto bitcastOp) { + visitVectorBitcastOp(bitcastOp, operands, results); + }) + .Case([&](auto reductionOp) { + visitVectorMultiReductionOp(reductionOp, operands, results); + }) + /// All other ops. + .Default([&](Operation *op) { + for (const SGMapLattice *r : results) { + for (SGMapLattice *operand : operands) { + /// Propagate the layout of the result to the operand. + if (r->getValue().isAssigned()) + meet(operand, *r); + } + } + }); + /// Add a dependency from each result to program point after the operation. 
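+  /// This ensures the solver re-invokes visitOperation whenever a result
+  /// lattice changes later, e.g. when a consumer of the result (such as a
+  /// DPAS) is visited afterwards and assigns it a layout.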
+ for (const SGMapLattice *r : results) { + addDependency(const_cast(r), getProgramPointAfter(op)); + } + return success(); +} + +void SGMapPropagation::visitVectorMultiReductionOp( + vector::MultiDimReductionOp reduction, ArrayRef operands, + ArrayRef results) { + /// The layout of the result must be present. + auto resultLayout = results[0]->getValue(); + if (!resultLayout.isAssigned()) + return; + /// We only consider 2D -> 1D reductions at this point. + assert(resultLayout.getLayout().size() == 1 && + "Expected 1D layout for reduction result."); + /// Given that the result is 1D, the layout of the operand should be 2D with + /// default layout. + auto operandLayout = getDefaultSgMap(2); + propagateIfChanged(operands[0], operands[0]->meet(operandLayout)); + /// Accumulator should have the same layout as the result. + propagateIfChanged(operands[1], operands[1]->meet(resultLayout)); +} + +/// Propagate the layout of the result tensor to the source tensor descriptor in +/// UpdateNdOffsetOp. +void SGMapPropagation::visitUpdateNdOffsetOp( + xegpu::UpdateNdOffsetOp updateNdOffset, ArrayRef operands, + ArrayRef results) { + /// The layout of the result must be present. + auto resultLayout = results[0]->getValue(); + if (!resultLayout.isAssigned()) + return; + /// Propagate the layout to the source operand. + propagateIfChanged(operands[0], operands[0]->meet(resultLayout)); +} + +/// Set the layouts for DPAS A, B, and C operands. +void SGMapPropagation::visitDpasOp(xegpu::DpasOp dpas, + ArrayRef operands, + ArrayRef results) { + auto aTy = dpas.getLhsType(); + auto bTy = dpas.getRhsType(); + propagateIfChanged(operands[0], + operands[0]->meet(getSGMapForDPASOperand(aTy, 0))); + propagateIfChanged(operands[1], + operands[1]->meet(getSGMapForDPASOperand(bTy, 1))); + if (operands.size() > 2) { + auto cTy = dpas.getAccType(); + propagateIfChanged(operands[2], + operands[2]->meet(getSGMapForDPASOperand(cTy, 2))); + } +}; + +/// Set the layout for the value and tensor descriptor operands in StoreNdOp. +void SGMapPropagation::visitStoreNdOp(xegpu::StoreNdOp store, + ArrayRef operands, + ArrayRef results) { + auto storeLayout = getDefaultSgMap(store.getValueType()); + /// Both operands should have the same layout + for (SGMapLattice *operand : operands) { + propagateIfChanged(operand, operand->meet(storeLayout)); + } +} + +/// Propagate the layout of the value to the tensor descriptor operand in +/// LoadNdOp. +void SGMapPropagation::visitLoadNdOp(xegpu::LoadNdOp load, + ArrayRef operands, + ArrayRef results) { + auto valueLayout = results[0]->getValue(); + /// Need the layout of the value to propagate to the tensor descriptor. + if (!valueLayout.isAssigned()) + return; + SGMap tensorDescLayout = valueLayout; + /// LoadNdOp has the transpose effect. However, at the stage of this analysis + /// this effect is not expected and should be abstracted away. Emit a warning. + if (auto transpose = load.getTranspose()) { + load.emitWarning("Transpose effect is not expected for LoadNdOp at " + "SGMapPropagation stage."); + tensorDescLayout = valueLayout.getTransposedLayout(transpose.value()); + } + /// Propagate the new layout to the tensor descriptor operand. + propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); +} + +/// For vector::TransposeOp, the layout of the result is transposed and +/// propagated to the operand. 
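+/// For example, in @test_vector_transpose below, the DPAS fixes the transpose
+/// result to wi_layout [1, 16] and wi_data [2, 1]; applying the permutation
+/// [1, 0] propagates wi_layout [16, 1] and wi_data [1, 2] to the transpose
+/// operand and, from there, to the load_nd feeding it.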
+void SGMapPropagation::visitTransposeOp( + vector::TransposeOp transpose, ArrayRef operands, + ArrayRef results) { + /// Need the layout of transpose result to propagate to the operands. + auto resultLayout = results[0]->getValue(); + if (!resultLayout.isAssigned()) + return; + auto newLayout = resultLayout.getTransposedLayout(transpose.getPermutation()); + /// Propagate the new layout to the vector operand. + propagateIfChanged(operands[0], operands[0]->meet(newLayout)); +} + +/// For vector::BitCastOp, the wi_data of the source layout is changed based on +/// the bit width of the source and result types. +void SGMapPropagation::visitVectorBitcastOp( + vector::BitCastOp bitcast, ArrayRef operands, + ArrayRef results) { + /// Need the layout of bitcast result to propagate to the operands. + auto resultLayout = results[0]->getValue(); + if (!resultLayout.isAssigned()) + return; + auto inElemTyBitWidth = + bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth(); + auto outElemTyBitWidth = + bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth(); + + /// WiLayout does not change. + const WiLayout &newWiLayout = resultLayout.getLayout(); + const WiData &currData = resultLayout.getData(); + WiData newWiData; + /// It's a widening bitcast + if (inElemTyBitWidth < outElemTyBitWidth) { + auto ratio = outElemTyBitWidth / inElemTyBitWidth; + newWiData = resultLayout.getData()[0] == 1 + ? WiData({1, currData[1] * ratio}) + : WiData({currData[0] * ratio, 1}); + } else { + /// It's a narrowing bitcast + auto ratio = inElemTyBitWidth / outElemTyBitWidth; + newWiData = resultLayout.getData()[0] == 1 + ? WiData({1, currData[1] / ratio}) + : WiData({currData[0] / ratio, 1}); + } + + propagateIfChanged(operands[0], + operands[0]->meet(SGMap(newWiLayout, newWiData))); +} + +/// Propagate the layout of the result to the tensor descriptor and mask +/// operands in LoadGatherOp. +void SGMapPropagation::visitLoadGatherOp( + xegpu::LoadGatherOp load, ArrayRef operands, + ArrayRef results) { + auto valueLayout = results[0]->getValue(); + /// Need the layout of the value to propagate to the tensor descriptor. + if (!valueLayout.isAssigned()) + return; + + SGMap tensorDescLayout = valueLayout; + if (load.getTranspose()) { + /// LoadGatherOp has the transpose effect. However, at the stage of this + /// analyis this effect is not expected and should be abstracted away. Emit + /// a warning. + load.emitWarning("Transpose effect is not expected for LoadGatherOp at " + "SGMapPropagation stage."); + tensorDescLayout = valueLayout.getTransposedLayout({1, 0}); + } + /// Mask operand should have 1D default layout. + auto maskLayout = getDefaultSgMap(1); + /// Propagate the new layout to the tensor descriptor operand. + propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout)); + /// Propagate the new layout to the mask operand. + propagateIfChanged(operands[1], operands[1]->meet(maskLayout)); +} + +/// Propagate the layout of the descriptor to the vector offset operand in +/// CreateDescOp. +void SGMapPropagation::visitCreateDescOp( + xegpu::CreateDescOp createDesc, ArrayRef operands, + ArrayRef results) { + auto descLayout = results[0]->getValue(); + /// Need the layout of the descriptor to propagate to the operands. + if (!descLayout.isAssigned()) + return; + /// For offset operand propagate 1D default layout. 
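+  /// (i.e. wi_layout [16], wi_data [1]: each of the 16 work items owns one
+  /// offset, as checked in @test_load_gather_1d below).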
+  SGMap layout = getDefaultSgMap(1);
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
+}
+
+/// Set the layout for the value, tensor descriptor, and mask operands in the
+/// StoreScatterOp.
+void SGMapPropagation::visitStoreScatterOp(
+    xegpu::StoreScatterOp storeScatter, ArrayRef<SGMapLattice *> operands,
+    ArrayRef<const SGMapLattice *> results) {
+  /// Currently, for 2D StoreScatterOp we expect that the height dimension of
+  /// the tensor descriptor is evenly divisible by the subgroup size.
+  /// TODO: Add support for other 2D shapes.
+  auto tdescShape = storeScatter.getTensorDescType().getShape();
+  if (tdescShape.size() > 1 && tdescShape[0] % subgroupSize != 0) {
+    storeScatter.emitError("Height dimension of the tensor descriptor should "
+                           "be evenly divisible by the subgroup size.");
+    return;
+  }
+  auto valueLayout = getDefaultSgMap(storeScatter.getValueType());
+  SGMap storeScatterLayout = valueLayout;
+  if (storeScatter.getTranspose()) {
+    /// StoreScatterOp allows a transpose effect. However, at the stage of this
+    /// analysis this effect is not expected and should be abstracted away.
+    /// Emit a warning.
+    storeScatter.emitWarning("Transpose effect is not expected for "
+                             "StoreScatterOp at SGMapPropagation stage.");
+    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
+  }
+  /// Propagate the value layout.
+  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  /// Propagate the tensor descriptor layout.
+  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+  /// Use the default 1D layout for the mask operand.
+  auto maskLayout = getDefaultSgMap(1);
+  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+}
+
+namespace {
+
+///===----------------------------------------------------------------------===///
+/// RunSGMapPropagation
+///===----------------------------------------------------------------------===///
+
+/// Driver class for running the SGMapPropagation analysis.
+class RunSGMapPropagation {
+public:
+  RunSGMapPropagation(Operation *op) : target(op) {
+    SymbolTableCollection symbolTable;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
+    solver.load<SGMapPropagation>(symbolTable);
+    (void)solver.initializeAndRun(op);
+  }
+
+  SGMap getSGMap(Value val);
+
+  void printAnalysisResult(llvm::raw_ostream &os);
+
+private:
+  DataFlowSolver solver;
+  const Operation *target;
+};
+} // namespace
+
+SGMap RunSGMapPropagation::getSGMap(Value val) {
+  auto *state = solver.lookupState<SGMapLattice>(val);
+  if (!state)
+    return {};
+  return state->getValue();
+}
+
+void RunSGMapPropagation::printAnalysisResult(llvm::raw_ostream &os) {
+  auto printFunctionResult = [&](FunctionOpInterface funcOp) {
+    os << "function: " << funcOp.getName() << ":\n";
+    // Function arguments
+    for (auto arg : funcOp.getArguments()) {
+      auto layout = getSGMap(arg);
+      os << "argument: " << arg << "\n";
+      os << "sg_map : ";
+      layout.print(os);
+      os << "\n";
+    }
+    // Function ops
+    funcOp.walk([&](Operation *op) {
+      // Skip ops that do not have results
+      if (op->getResults().empty())
+        return;
+      os << "op : ";
+      /// For control-flow ops, print the op name only.
+      if (isa<BranchOpInterface>(op) || isa<RegionBranchOpInterface>(op))
+        os << op->getName();
+      else
+        op->print(os);
+      os << "\n";
+      /// Print the sg_map for each result.
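+      /// Example line: "sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]".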
+ for (auto [i, r] : llvm::enumerate(op->getResults())) { + auto layout = getSGMap(r); + os << "sg_map for result #" << i << ": "; + layout.print(os); + os << "\n"; + } + }); + }; + + SmallVector funcOps; + if (auto modOp = dyn_cast(target)) { + for (auto funcOp : modOp.getOps()) { + funcOps.push_back(funcOp); + } + /// Collect all GpuFuncOps in the module. + for (auto gpuModOp : modOp.getOps()) { + for (auto gpuFuncOp : gpuModOp.getOps()) { + funcOps.push_back(gpuFuncOp); + } + } + } + /// Print the analysis result for each function. + for (auto funcOp : funcOps) { + printFunctionResult(funcOp); + } +} + +namespace { +struct XeGPUSubgroupDistributePass final + : public xegpu::impl::XeGPUSubgroupDistributeBase< + XeGPUSubgroupDistributePass> { + XeGPUSubgroupDistributePass() = default; + XeGPUSubgroupDistributePass(const XeGPUSubgroupDistributePass &other) = + default; + XeGPUSubgroupDistributePass(xegpu::XeGPUSubgroupDistributeOptions options) + : XeGPUSubgroupDistributeBase(options) {} + void runOnOperation() override; +}; +} // namespace + +void XeGPUSubgroupDistributePass::runOnOperation() { + Operation *op = getOperation(); + RunSGMapPropagation solver(op); + + // Print the analysis result and exit. + if (printOnly) { + auto &os = llvm::outs(); + solver.printAnalysisResult(os); + return; + } +} diff --git a/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir new file mode 100644 index 0000000000000..1ae4348af33e6 --- /dev/null +++ b/mlir/test/Dialect/XeGPU/subgroup-map-propagation.mlir @@ -0,0 +1,563 @@ +// RUN: mlir-opt -xegpu-subgroup-distribute='print-analysis-only=true' -split-input-file %s | FileCheck %s + +// CHECK: function: test_dpas_f16: +// CHECK-NEXT: argument: of type 'memref<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<16x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. 
+// CHECK-NEXT: op : %{{.*}} = arith.constant dense<0.000000e+00> : vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %{{.*}} = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + + +// ----- +// CHECK: function: test_dpas_i8: +// CHECK-NEXT: argument: of type 'vector<8x32xi8>' at index: 0 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 2] +// CHECK-NEXT: argument: of type 'vector<32x16xi8>' at index: 1 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [4, 1] +// CHECK-NEXT: argument: of type 'memref<8x16xi32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. 
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.dpas %{{.*}} : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) { + %c0 = arith.constant 0 : index + %0 = xegpu.dpas %arg0, %arg1 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> + %1 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32> + xegpu.store_nd %0, %1 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32> + return +} + +// ----- +// CHECK: function: test_load_with_transpose_effect: +// CHECK-NEXT: argument: of type 'memref<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<16x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. +// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] <{transpose = array}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_load_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 <{transpose = array}> : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = xegpu.dpas %2, %3, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + +// ----- +// CHECK: function: test_vector_transpose: +// CHECK-NEXT: 
argument: of type 'memref<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<16x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. +// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T4:.*]] = vector.transpose %[[T3]], [1, 0] : vector<16x16xf16> to vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T5:.*]] = xegpu.dpas %[[T2]], %[[T4]], %[[CST]] : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_vector_transpose(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32> + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16> + %5 = xegpu.dpas %2, %4, %cst : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32> + %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + +// ----- +// CHECK: function: test_extf_truncf: +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T2:.*]] = 
arith.extf %[[T1]] : vector<16x16xf16> to vector<16x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T3:.*]] = arith.truncf %[[T2]] : vector<16x16xf32> to vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: Not assigned. +func.func @test_extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) -> vector<8x16xf32> { + %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %2 = arith.extf %1 : vector<16x16xf16> to vector<16x16xf32> + %3 = arith.truncf %2 : vector<16x16xf32> to vector<16x16xf16> + %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + return %4 : vector<8x16xf32> +} + +// ----- +// CHECK: function: test_load_gather_with_transpose_effect: +// CHECK-NEXT: argument: of type 'memref<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<256xf16>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense : vector<16xi1> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load %[[T2]], %[[CST0]] <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T4:.*]] = xegpu.dpas %[[T1]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T5:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) { + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16> + %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> + 
%cst_0 = arith.constant dense : vector<16xi1> + %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr> + %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr>, vector<16xi1> -> vector<16x16xf16> + %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %4, %5 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + +// ----- +// CHECK: function: test_load_gather_1d: +// CHECK: argument: of type 'memref<256xf32>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16xf32>' at index: 1 +// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense : vector<16xi1> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[T1]] = xegpu.load %[[T0]], %[[CST0]] : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +func.func @test_load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) { + %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> + %cst_0 = arith.constant dense : vector<16xi1> + %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32> + xegpu.store_nd %1, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32> + return +} + +// ----- +// CHECK: function: test_store_scatter_with_transpose_effect: +// CHECK-NEXT: argument: of type 'memref<128xf32>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. 
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[CST0:.*]] = arith.constant dense : vector<16xi1> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST1]] : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16, 1], wi_data: [1, 1] +func.func @test_store_scatter_with_transpose_effect(%arg0: memref<128xf32>) { + %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32> + %cst_0 = arith.constant dense : vector<16xi1> + %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> + %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr> + xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + return +} + +// ----- +// CHECK: function: test_store_scatter_1d: +// CHECK-NEXT: argument: of type 'vector<16xf32>' at index: 0 +// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1] +// CHECK-NEXT: argument: of type 'memref<256xf32>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[CST1:.*]] = arith.constant dense : vector<16xi1> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_tdesc %{{.*}}, %[[CST]] : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> +// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1] +func.func @test_store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) { + %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex> + %cst_0 = arith.constant dense : vector<16xi1> + %0 = xegpu.create_tdesc %arg1, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>> + xegpu.store %arg0, %0, %cst_0 : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> + return +} + +// ----- +// CHECK: function: test_vector_bitcast_i16_to_i8: +// CHECK-NEXT: argument: of type 'memref<8x16xi16>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<32x16xi8>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xi32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. 
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1] +// CHECK-NEXT: op : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x16xi16> to vector<8x32xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T5:.*]] = xegpu.dpas %[[T4]], %[[T3]] : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_vector_bitcast_i16_to_i8(%arg0: memref<8x16xi16>, %arg1: memref<32x16xi8>, %arg2: memref<8x16xi32>) { + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xi16> -> !xegpu.tensor_desc<8x16xi16> + %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xi8> -> !xegpu.tensor_desc<32x16xi8> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xi16> -> vector<8x16xi16> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<32x16xi8> -> vector<32x16xi8> + %4 = vector.bitcast %2 : vector<8x16xi16> to vector<8x32xi8> + %5 = xegpu.dpas %4, %3 : vector<8x32xi8>, vector<32x16xi8> -> vector<8x16xi32> + %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xi32> -> !xegpu.tensor_desc<8x16xi32> + xegpu.store_nd %5, %6 : vector<8x16xi32>, !xegpu.tensor_desc<8x16xi32> + return +} + +// ----- +// CHECK: function: test_vector_bitcast_i8_to_f16: +// CHECK-NEXT: argument: of type 'memref<8x32xi8>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<16x32xi8>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. 
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 2] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %[[T1]] : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [4, 1] +// CHECK-NEXT: op : %[[T4:.*]] = vector.bitcast %[[T2]] : vector<8x32xi8> to vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T5:.*]] = vector.bitcast %[[T3]] : vector<16x32xi8> to vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T7:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_vector_bitcast_i8_to_f16(%arg0: memref<8x32xi8>, %arg1: memref<16x32xi8>, %arg2: memref<8x16xf32>) { + %c0 = arith.constant 0 : index + %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xi8> -> !xegpu.tensor_desc<8x32xi8> + %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<16x32xi8> -> !xegpu.tensor_desc<16x32xi8> + %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x32xi8> -> vector<8x32xi8> + %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<16x32xi8> -> vector<16x32xi8> + %4 = vector.bitcast %2 : vector<8x32xi8> to vector<8x16xf16> + %5 = vector.bitcast %3 : vector<16x32xi8> to vector<16x16xf16> + %6 = xegpu.dpas %4, %5 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + %7 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %6, %7 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + +// ----- +// CHECK: function: test_binary_op_one_use: +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T2:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T3:.*]] = arith.addf %[[T1]], %[[T2]] : vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1] +// CHECK-NEXT: op : %[[T4:.*]] = 
xegpu.dpas %[[T0]], %[[T3]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_binary_op_one_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>) { + %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %2 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %3 = arith.addf %1, %2 : vector<16x16xf16> + %4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + xegpu.store_nd %4, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + return +} + +// ----- +// CHECK: function: test_binary_op_multiple_uses: +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf16>' at index: 0 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 3 +// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T1:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T2:.*]] = arith.addf %[[T1]], %[[CST]] : vector<16x16xf16> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +// CHECK-NEXT: op : %[[T3:.*]] = xegpu.dpas %[[T0]], %[[T2]] : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> +// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1] +func.func @test_binary_op_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: !xegpu.tensor_desc<8x16xf32>, %arg3: !xegpu.tensor_desc<16x16xf16>) { + %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %1 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16> + %cst = arith.constant dense<1.000000e+00> : vector<16x16xf16> + %2 = arith.addf %1, %cst : vector<16x16xf16> + %3 = xegpu.dpas %0, %2 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32> + xegpu.store_nd %3, %arg2 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + xegpu.store_nd %2, %arg3 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16> + return +} + +// ----- +// CHECK: function: test_for_op: +// CHECK-NEXT: argument: of type 'memref<8x128xf16>' at index: 0 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<128x16xf16>' at index: 1 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: argument: of type 'memref<8x16xf32>' at index: 2 +// CHECK-NEXT: sg_map : Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. +// CHECK-NEXT: op : %{{.*}} = arith.constant 128 : index +// CHECK-NEXT: sg_map for result #0: Not assigned. 
+// CHECK-NEXT: op : %{{.*}} = arith.constant 16 : index
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T1:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T5:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T6:.*]] = xegpu.dpas %[[T4]], %[[T5]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T7:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T8:.*]] = xegpu.update_nd_offset %{{.*}} : !xegpu.tensor_desc<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : scf.for
+// CHECK-NEXT: sg_map for result #0: Not assigned.
+// CHECK-NEXT: sg_map for result #1: Not assigned.
+// CHECK-NEXT: sg_map for result #2: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_for_op(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16> -> !xegpu.tensor_desc<16x16xf16>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  %2:3 = scf.for %arg3 = %c0 to %c128 step %c16 iter_args(%arg4 = %0, %arg5 = %1, %arg6 = %cst) -> (!xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>) {
+    %4 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg6 : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    %7 = xegpu.update_nd_offset %arg4, [%c0, %c16] : !xegpu.tensor_desc<8x16xf16>
+    %8 = xegpu.update_nd_offset %arg5, [%c16, %c0] : !xegpu.tensor_desc<16x16xf16>
+    scf.yield %7, %8, %6 : !xegpu.tensor_desc<8x16xf16>, !xegpu.tensor_desc<16x16xf16>, vector<8x16xf32>
+  }
+  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2#2, %3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// CHECK: function: test_if_single_use:
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: argument: of type 'i1' at index: 2
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [2, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_if_single_use(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>) {
+  %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
+}
+
+// -----
+// CHECK: function: test_if_multiple_uses:
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf16>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: of type 'i1' at index: 2
+// CHECK-NEXT: sg_map : Not assigned.
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<8x16xf32>' at index: 3
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16x16xf16>' at index: 4
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T0:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T3:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T4:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : scf.if
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: op : %[[T2:.*]] = xegpu.dpas %[[T0]], %{{.*}} : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [1, 16], wi_data: [1, 1]
+func.func @test_if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>, %arg2: i1, %arg3: !xegpu.tensor_desc<8x16xf32>, %arg4: !xegpu.tensor_desc<16x16xf16>) {
+  %0 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    scf.yield %3 : vector<16x16xf16>
+  }
+  %2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %2, %arg3 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %1, %arg4 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// CHECK: function: test_vector_outer_reduction:
+// CHECK-NEXT: argument: of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+func.func @test_vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}
+
+// -----
+// CHECK: function: test_vector_inner_reduction:
+// CHECK-NEXT: argument: of type 'vector<16x16xf32>' at index: 0
+// CHECK-NEXT: sg_map : wi_layout: [1, 16], wi_data: [1, 1]
+// CHECK-NEXT: argument: of type '!xegpu.tensor_desc<16xf32>' at index: 1
+// CHECK-NEXT: sg_map : wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: op : %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+// CHECK-NEXT: op : %[[T0:.*]] = vector.multi_reduction <add>, %{{.*}}, %[[CST]] [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK-NEXT: sg_map for result #0: wi_layout: [16], wi_data: [1]
+func.func @test_vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
+  %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
+  %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>
+  xegpu.store_nd %0, %arg1 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
+  return
+}