diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6d04ee5599a23..032ce5bc18334 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -303,7 +303,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }
-
   }];

   let assemblyFormat = "`<` struct(params) `>`";
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..09f9ce1e716c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -14,11 +14,67 @@ class RewritePatternSet;

 namespace xegpu {

+/// Options to control XeGPU unrolling. Its main purpose is to provide a way
+/// to customize the native (tile) shape of an operation.
+struct UnrollOptions {
+  /// Callback function that indicates whether unrolling should be attempted
+  /// on the operation.
+  using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
+  FilterConstraintFnType filterConstraint = nullptr;
+  UnrollOptions &setFilterConstraint(FilterConstraintFnType constraint) {
+    filterConstraint = std::move(constraint);
+    return *this;
+  }
+
+  /// Function that computes the target shape for unrolling. It returns an
+  /// optional vector of integers representing the shape. If it returns
+  /// `std::nullopt`, unrolling is aborted for the given operation.
+  using NativeShapeFnType =
+      std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
+  NativeShapeFnType nativeShape = nullptr;
+  UnrollOptions &setNativeShapeFn(NativeShapeFnType fn) {
+    nativeShape = std::move(fn);
+    return *this;
+  }
+
+  /// Function that converts a ShapedType (TensorDescType or VectorType)
+  /// into the unrolled type based on the tileShape. It returns a vector of
+  /// types representing the unrolled types for simplicity.
+  using UnrolledTypeFnType = std::function<SmallVector<Type>(
+      ShapedType type, ArrayRef<int64_t> tileShape)>;
+  UnrolledTypeFnType getUnrolledTypes = nullptr;
+  UnrollOptions &setUnrolledTypesFn(UnrolledTypeFnType fn) {
+    getUnrolledTypes = std::move(fn);
+    return *this;
+  }
+};
+
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
+
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);

+/// Collect a set of patterns to unroll XeGPU operations into smaller shapes.
+/// Users control whether an operation should be unrolled, and what its target
+/// shape is, via the `options` structure (by setting `filterConstraint` and
+/// `nativeShape` respectively; both are callbacks taking the `op` as input).
+/// An `op` is unrolled to the `targetShape` as follows, for each of its
+/// operands:
+///   1. the unrolled type `unrolledType` and number of unrolled instances
+///      `numUnrolledInstances` are computed from the `targetShape`.
+///   2. pack each operand: ExtractStridedSlice ops are created to break up
+///      the vector operands, and builtin.unrealized_conversion_cast ops are
+///      created to break up the TensorDesc operands.
+///   3. the original op is cloned `numUnrolledInstances` times, once for
+///      each unrolled instance, using the corresponding packed operands.
+///   4. unpack the results: InsertStridedSlice ops are inserted for
+///      VectorType results, and builtin.unrealized_conversion_cast ops are
+///      inserted for TensorDescType results, to re-assemble the slices into
+///      the original shape.
+void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
+                                 const UnrollOptions &options);
+
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f2cfa50e102f8..c99e925a97633 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..892eb791c46e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUUnroll.cpp

   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
new file mode 100644
index 0000000000000..44d45dd2eaec0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -0,0 +1,427 @@
+//===- XeGPUUnroll.cpp - patterns to do unrolling ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for unrolling XeGPU operations. It follows a
+// similar concept and design to the vector unroll patterns, serving as a
+// complement to them.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUUNROLL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+template <typename SourceOp>
+struct UnrollPattern : public OpRewritePattern<SourceOp> {
+  UnrollPattern(MLIRContext *context, const xegpu::UnrollOptions &options,
+                PatternBenefit benefit = 1)
+      : OpRewritePattern<SourceOp>(context, benefit), options(options) {}
+
+protected:
+  /// Return the target shape for the given `op`. Return std::nullopt if the
+  /// op shouldn't be or cannot be unrolled.
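+  /// For example, a nativeShape callback that returns {8, 16} for 2-D
+  /// tensor descriptors would cause a 24x32 op to be unrolled into a 3x2
+  /// grid of 8x16 instances; the chosen shape is entirely caller-defined.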
+ std::optional> getTargetShape(Operation *op) const { + LDBG(""); + LDBG("Get unroll shape for: " << *op); + + if (options.filterConstraint && failed(options.filterConstraint(op))) { + LDBG("--no filter constraint -> BAIL"); + return std::nullopt; + } + + assert(options.nativeShape && + "expects the native shape for native shape call back function."); + auto nativeShape = options.nativeShape(op); + return nativeShape; + } + + SmallVector getUnrolledTypes(ShapedType type, + ArrayRef tileShape) const { + return options.getUnrolledTypes(type, tileShape); + } + + /// Emulate the the unpack behavior using insert_strided_slice for VectorType + /// values and unrealized_conversion_cast for TensorDescType values. + Value unpack(ValueRange srcs, Type destTy, ArrayRef blockSize, + Location loc, PatternRewriter &rewriter) const { + if (auto vecTy = dyn_cast(destTy)) { + assert(vecTy.getRank() == static_cast(blockSize.size()) && + "Expecting blockSize size to match the rank of destTy."); + auto shape = vecTy.getShape(); + auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType()); + + Value result = rewriter.create( + loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr)); + for (auto [src, offsets] : + llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) { + SmallVector staticStrides(offsets.size(), 1); + result = rewriter.create( + loc, src, result, offsets, staticStrides); + } + return result; + } + + if (isa(destTy)) { + auto attr = NamedAttribute(rewriter.getStringAttr(unpackAttrName), + rewriter.getUnitAttr()); + auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName), + rewriter.getDenseI64ArrayAttr(blockSize)); + auto castOp = rewriter.create( + loc, destTy, srcs, ArrayRef({attr, blkAttr})); + return castOp.getResult(0); + } + + llvm_unreachable("Unexpected destTy."); + return Value(); + } + + /// Emulate the the pack behavior using extract_strided_slice for VectorType + /// values and unrealized_conversion_cast for TensorDescType values. 
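+  /// For example, packing a vector<24x32xf32> value with blockSize {8, 16}
+  /// yields six vector<8x16xf32> slices (row-major over the 3x2 tile grid),
+  /// while a TensorDesc value becomes the results of a single
+  /// unrealized_conversion_cast op tagged with the pack and block-size
+  /// attributes defined below.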
+ SmallVector pack(Value src, TypeRange destTypes, + ArrayRef blockSize, Location loc, + PatternRewriter &rewriter) const { + if (auto vecTy = dyn_cast(src.getType())) { + assert(vecTy.getRank() == static_cast(blockSize.size()) && + "Expecting blockSize size to match the rank of src."); + auto shape = vecTy.getShape(); + SmallVector results; + for (SmallVector offsets : + StaticTileOffsetRange(shape, blockSize)) { + SmallVector staticStrides(offsets.size(), 1); + auto slice = rewriter.create( + loc, src, offsets, blockSize, staticStrides); + results.push_back(slice); + } + return results; + } + + if (isa(src.getType())) { + auto attr = NamedAttribute(rewriter.getStringAttr(packAttrName), + rewriter.getUnitAttr()); + auto blkAttr = NamedAttribute(rewriter.getStringAttr(blockAttrName), + rewriter.getDenseI64ArrayAttr(blockSize)); + auto castOp = rewriter.create( + loc, destTypes, src, ArrayRef({attr, blkAttr})); + return castOp.getResults(); + } + + llvm_unreachable("Unexpected src type."); + return SmallVector(); + } + +private: + const char *const packAttrName = "__xegpu_blocking_pack__"; + const char *const unpackAttrName = "__xegpu_blocking_unpack__"; + const char *const blockAttrName = "__xegpu_blocking_tile_shape__"; + + xegpu::UnrollOptions options; +}; + +struct UnrollCreateNdOp : public UnrollPattern { + using UnrollPattern::UnrollPattern; + LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + xegpu::TensorDescType tdescTy = op.getType(); + int64_t rank = tdescTy.getRank(); + ArrayRef shape = tdescTy.getShape(); + + std::optional> targetShape = getTargetShape(op); + if (!targetShape || llvm::equal(*targetShape, shape)) + return failure(); + + auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0]; + + auto addi = [&](OpFoldResult a, int64_t b) -> Value { + std::optional maybeInt = getConstantIntValue(a); + if (maybeInt) { + return rewriter.create(loc, *maybeInt + b); + } else { + auto aV = llvm::cast(a); + auto bV = rewriter.create(loc, b); + return rewriter.createOrFold(loc, aV, bV); + } + }; + + SmallVector mixedOffsets = op.getMixedOffsets(); + + // For n-D memrefs where n > rank, we need to handle the last `rank` + // dimensions only, and keep the first `n-rank` dimensions as is. 
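+    // For example, with source offsets [%a, %b, %c] and a 2-D tensor_desc,
+    // only [%b, %c] are advanced by the tile offsets; %a is passed through
+    // unchanged.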
+ SmallVector oldOffsets = llvm::to_vector( + llvm::drop_begin(mixedOffsets, mixedOffsets.size() - rank)); + auto validIdxes = + llvm::seq(mixedOffsets.size() - rank, mixedOffsets.size()); + + SmallVector newOps; + for (SmallVector offsets : + StaticTileOffsetRange(shape, *targetShape)) { + + for (auto [idx, oldOff, offset] : + llvm::zip(validIdxes, oldOffsets, offsets)) + mixedOffsets[idx] = addi(oldOff, offset); + + auto newOp = rewriter.create( + loc, newTdescTy, op.getSource(), mixedOffsets, op.getMixedSizes(), + op.getMixedStrides()); + newOps.push_back(newOp); + } + Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter); + rewriter.replaceOp(op, castOp); + + return success(); + } +}; + +struct UnrollUpdateNdOffsetOp : public UnrollPattern { + using UnrollPattern::UnrollPattern; + LogicalResult matchAndRewrite(xegpu::UpdateNdOffsetOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + xegpu::TensorDescType tdescTy = op.getTensorDescType(); + ArrayRef shape = tdescTy.getShape(); + + std::optional> targetShape = getTargetShape(op); + if (!targetShape || llvm::equal(*targetShape, shape)) + return failure(); + + SmallVector convertedTdescTypes = + getUnrolledTypes(tdescTy, *targetShape); + SmallVector convertedTdesc = pack( + op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); + + SmallVector newOps; + for (auto t : convertedTdesc) { + auto newOp = rewriter.create( + loc, t.getType(), t, op.getOffsets(), op.getConstOffsets()); + newOps.push_back(newOp); + } + Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); + rewriter.replaceOp(op, castOp); + return success(); + } +}; + +struct UnrollPrefetchNdOp : public UnrollPattern { + using UnrollPattern::UnrollPattern; + LogicalResult matchAndRewrite(xegpu::PrefetchNdOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + xegpu::TensorDescType tdescTy = op.getTensorDescType(); + ArrayRef shape = tdescTy.getShape(); + + std::optional> targetShape = getTargetShape(op); + if (!targetShape || llvm::equal(*targetShape, shape)) + return failure(); + + SmallVector convertedTdescTypes = + getUnrolledTypes(tdescTy, *targetShape); + SmallVector convertedTdesc = pack( + op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); + + for (auto t : convertedTdesc) + rewriter.create(loc, TypeRange(), t, op->getAttrs()); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct UnrollLoadNdOp : public UnrollPattern { + using UnrollPattern::UnrollPattern; + LogicalResult matchAndRewrite(xegpu::LoadNdOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + VectorType valueTy = op.getType(); + xegpu::TensorDescType tdescTy = op.getTensorDescType(); + ArrayRef shape = tdescTy.getShape(); + + std::optional> targetShape = getTargetShape(op); + if (!targetShape || llvm::equal(*targetShape, shape)) + return failure(); + + Type elemTy = tdescTy.getElementType(); + VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy); + + SmallVector convertedTdescTypes = + getUnrolledTypes(tdescTy, *targetShape); + SmallVector convertedTdescs = pack( + op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); + + SmallVector newOps; + for (auto t : convertedTdescs) { + auto newOp = + rewriter.create(loc, newValueTy, t, op->getAttrs()); + newOps.push_back(newOp); + } + + Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter); + + rewriter.replaceOp(op, castOp); + return success(); + } +}; 
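+
+// Illustrative example for the pattern above: assuming the options callbacks
+// select a native shape of 8x16, a load_nd producing vector<24x32xf32> is
+// rewritten into a pack of its tensor_desc into six 8x16 descriptors, six
+// 8x16 load_nd ops, and an unpack that reassembles the six vector<8x16xf32>
+// results into the original vector<24x32xf32> via insert_strided_slice (see
+// the accompanying tests).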
+ +struct UnrollStoreNdOp : public UnrollPattern { + using UnrollPattern::UnrollPattern; + LogicalResult matchAndRewrite(xegpu::StoreNdOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + VectorType valueTy = op.getValueType(); + xegpu::TensorDescType tdescTy = op.getTensorDescType(); + ArrayRef shape = tdescTy.getShape(); + + std::optional> targetShape = getTargetShape(op); + if (!targetShape || llvm::equal(*targetShape, shape)) + return failure(); + + SmallVector convertedValTypes = + getUnrolledTypes(valueTy, *targetShape); + SmallVector convertedTdescTypes = + getUnrolledTypes(tdescTy, *targetShape); + + SmallVector convertedValues = + pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter); + SmallVector convertedTdescs = pack( + op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter); + + for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) + rewriter.create(loc, v, t, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct UnrollDpasOp : public UnrollPattern { + using UnrollPattern::UnrollPattern; + LogicalResult matchAndRewrite(xegpu::DpasOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // expecting every operands is a 2D Vector + if (llvm::any_of(op->getOperandTypes(), [&](Type type) { + auto vecTy = dyn_cast(type); + return !vecTy || vecTy.getRank() != 2; + })) + return failure(); + + // A vector of 3 elements should be returned, representing M, K, N + // respectively. + std::optional> targetShape = getTargetShape(op); + if (!targetShape || targetShape->size() != 3) + return failure(); + auto M = (*targetShape)[0]; + auto K = (*targetShape)[1]; + auto N = (*targetShape)[2]; + + int64_t aBlockSize[2] = {M, K}; + int64_t bBlockSize[2] = {K, N}; + int64_t cBlockSize[2] = {M, N}; + + auto packWrapper = [&](TypedValue val, + ArrayRef blockSize) { + VectorType type = val.getType(); + std::optional> grids = + computeShapeRatio(type.getShape(), blockSize); + assert(grids && "Expecting grids to be computed."); + auto numNewOps = computeProduct(*grids); + if (numNewOps == 1) + return SmallVector({val}); + VectorType newVecTy = type.cloneWith(blockSize, type.getElementType()); + SmallVector convertedTypes(numNewOps, newVecTy); + SmallVector values = + pack(val, convertedTypes, blockSize, loc, rewriter); + return values; + }; + + auto a = op.getLhs(); + auto b = op.getRhs(); + auto c = op.getAcc(); + + auto aShape = a.getType().getShape(); + auto bShape = b.getType().getShape(); + + SmallVector aVals, bVals, cVals; + aVals = packWrapper(a, aBlockSize); + bVals = packWrapper(b, bBlockSize); + + if (c) + cVals = packWrapper(c, cBlockSize); + + // Skip the operation if every operand has an invalid blocking size (empty) + // or if the original shape matches the blocking size (size == 1). + auto ranges = c ? 
SmallVector({aVals, bVals, cVals}) + : SmallVector({aVals, bVals}); + if (llvm::any_of(ranges, [](auto &v) { return v.size() == 0; }) || + llvm::all_of(ranges, [](auto &v) { return v.size() == 1; })) + return failure(); + + VectorType resultTy = op.getResult().getType(); + auto vecTy = VectorType::get(cBlockSize, resultTy.getElementType()); + + int64_t mIters = aShape[0] / M; + int64_t kIters = aShape[1] / K; + int64_t nIters = bShape[1] / N; + + SmallVector newOps; + for (int64_t i = 0; i < mIters; ++i) { + for (int64_t j = 0; j < nIters; ++j) { + Value tmpC; + if (c) + tmpC = cVals[i * nIters + j]; // init with acc + + for (int64_t k = 0; k < kIters; ++k) { + Value aVec = aVals[i * kIters + k]; + Value bVec = bVals[k * nIters + j]; + SmallVector operands({aVec, bVec}); + if (tmpC) + operands.push_back(tmpC); + + tmpC = rewriter.create(loc, vecTy, operands, + op->getAttrs()); + } + newOps.push_back(tmpC); + } + } + Value castOp = unpack(newOps, resultTy, cBlockSize, loc, rewriter); + rewriter.replaceOp(op, castOp); + return success(); + } +}; + +} // namespace + +void mlir::xegpu::populateXeGPUUnrollPatterns( + RewritePatternSet &patterns, const xegpu::UnrollOptions &options) { + patterns.add( + patterns.getContext(), options); +} diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir new file mode 100644 index 0000000000000..b911bb3bbdc1c --- /dev/null +++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s + +gpu.module @test { + + // CHECK-LABEL: test_create_nd_tdesc + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>, + // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout> {__xegpu_blocking_tile_shape__ = array, __xegpu_blocking_unpack__} + gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + } + + //----- + + // CHECK-LABEL: test_create_nd_tdesc_1d + // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> + // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> + // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast + // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32> + // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout> {__xegpu_blocking_tile_shape__ = array, __xegpu_blocking_unpack__} + gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout> { + %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout> + gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout> + } + + //----- + + // CHECK-LABEL: test_update_nd_tdesc + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // 
CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32> + gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + } + + //----- + + // CHECK-LABEL: test_update_nd_tdesc_1d + // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> + // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> + // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32> + gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout> { + %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout> + %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout> + gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout> + } + + //----- + + // CHECK-LABEL: test_prefetch_nd_tdesc + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> + gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return + } + + //----- + + // CHECK-LABEL: test_prefetch_nd_tdesc_1d + // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> + // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> + // CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32> + gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout> + xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout> + gpu.return + } + + //----- + // CHECK-LABEL: test_load_nd + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32> + gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> + gpu.return %ld : vector<24x32xf32> + } + + //----- + + // CHECK-LABEL: test_load_nd_1d + // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> + // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> + // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16xf32> -> vector<16xf32> + // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32> + gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> { + 
%tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout> + %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout> -> vector<64xf32> + gpu.return %data : vector<64xf32> + } + + //----- + + // CHECK-LABEL: test_store_nd + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + // CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.func @test_store_nd(%src: memref<24x32xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %data = arith.constant dense<9.0> : vector<24x32xf32> + xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return + } + + //----- + + // CHECK-LABEL: test_store_nd_1d + // CHECK-SAME: [[arg0:%.+]]: memref<64xf32> + // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32> + // CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32> + gpu.func @test_store_nd_1d(%src: memref<64xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout> + %data = arith.constant dense<9.0> : vector<64xf32> + xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout> + gpu.return + } + + //----- + + // CHECK-LABEL: test_createNd_loadNd_storeNd + // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32> + //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32> + //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32> + //CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32> + //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32> + //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32> + //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32> + gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) { + %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + %data = arith.constant dense<9.0> : vector<24x32xf32> + %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout> -> vector<24x32xf32> + %add = arith.addf %data, %ld : vector<24x32xf32> + xegpu.store_nd %add, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout> + gpu.return + } + + //----- + + // CHECK-LABEL: test_dpas + // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16> + //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16> + //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16> + //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32> + //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32> + gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> { + %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32> + gpu.return %c : vector<32x32xf32> + } +} diff --git a/mlir/test/lib/Dialect/CMakeLists.txt 
b/mlir/test/lib/Dialect/CMakeLists.txt index 29fb4441a24fd..a8fd70e6397a5 100644 --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -22,3 +22,4 @@ add_subdirectory(TestDyn) add_subdirectory(Tosa) add_subdirectory(Transform) add_subdirectory(Vector) +add_subdirectory(XeGPU) diff --git a/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt new file mode 100644 index 0000000000000..5236d8765eac8 --- /dev/null +++ b/mlir/test/lib/Dialect/XeGPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIRXeGPUTestPasses + TestXeGPUTransforms.cpp + + EXCLUDE_FROM_LIBMLIR +) + +mlir_target_link_libraries(MLIRXeGPUTestPasses PUBLIC + MLIRAffineUtils + MLIRIR + MLIRMemRefDialect + MLIRXeGPUDialect + MLIRPass + MLIRTransforms + MLIRGPUDialect + MLIRXeGPUTransforms +) diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp new file mode 100644 index 0000000000000..eaa3b988cad82 --- /dev/null +++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp @@ -0,0 +1,124 @@ +//===- TestXeGPUTransforms.cpp -- Test Vector transforms and lowerings ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::xegpu; + +namespace { + +struct TestXeGPUUnrollingPatterns + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns) + + StringRef getArgument() const final { + return "test-xegpu-unrolling-patterns"; + } + + StringRef getDescription() const final { + return "Test lowering patterns to unroll ops in the xegpu dialect"; + } + + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + TestXeGPUUnrollingPatterns() = default; + TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass) + : PassWrapper(pass) {} + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + xegpu::UnrollOptions options; + options.setNativeShapeFn( + [&](Operation *op) -> std::optional> { + if (isa(op)) { + xegpu::TensorDescType tdescTy; + if (auto createNdOp = dyn_cast(op)) { + tdescTy = createNdOp.getType(); + } else if (auto updateNdOp = + dyn_cast(op)) { + tdescTy = updateNdOp.getTensorDescType(); + } else if (auto prefetchNdOp = dyn_cast(op)) { + tdescTy = prefetchNdOp.getTensorDescType(); + } else if (auto loadNdOp = dyn_cast(op)) { + tdescTy = loadNdOp.getTensorDescType(); + } else if (auto storeNdOp = dyn_cast(op)) { + tdescTy = storeNdOp.getTensorDescType(); + } + + if (auto layout = tdescTy.getLayoutAttr()) { + auto inst_data = layout.getInstData(); + if (inst_data && layout.isSgLayout()) + return SmallVector(inst_data.asArrayRef().begin(), + inst_data.asArrayRef().end()); + } + } + + if (isa(op)) + return SmallVector{8, 16, 16}; + + return std::nullopt; + }); + + options.setUnrolledTypesFn( + [&](ShapedType type, ArrayRef 
tileShape) -> SmallVector { + Type elemTy = type.getElementType(); + Type newTy; + + // TensorDescType needs to drop the inst_data field in the layout + // attribute + if (auto tdescTy = dyn_cast(type)) { + Attribute encoding = tdescTy.getEncoding(); + auto layout = llvm::dyn_cast_if_present( + tdescTy.getLayout()); + if (layout) { + if (layout.getLaneLayout() == nullptr) + layout = xegpu::LayoutAttr(); + else + layout = layout.dropInstData(); + } + newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, + layout); + } else { + newTy = type.clone(tileShape, elemTy); + } + + std::optional> ratio = + computeShapeRatio(type.getShape(), tileShape); + assert(ratio && "Expecting the ratio to be valid."); + return SmallVector(computeProduct(*ratio), newTy); + }); + + RewritePatternSet patterns(ctx); + + populateXeGPUUnrollPatterns(patterns, options); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestXeGPULowerings() { + PassRegistration(); +} +} // namespace test +} // namespace mlir \ No newline at end of file diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt index a5a442909fc6d..3220dca282eac 100644 --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -46,6 +46,7 @@ if(MLIR_INCLUDE_TESTS) MLIRTilingInterfaceTestPasses MLIRTosaTestPasses MLIRVectorTestPasses + MLIRXeGPUTestPasses MLIRTestVectorToSPIRV MLIRLLVMTestPasses ) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 344576a44ca41..cdcf59b2add13 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -158,6 +158,7 @@ void registerTestVectorLowerings(); void registerTestVectorReductionToSPIRVDotProd(); void registerTestVulkanRunnerPipeline(); void registerTestWrittenToPass(); +void registerTestXeGPULowerings(); #if MLIR_ENABLE_PDL_IN_PATTERNMATCH void registerTestDialectConversionPasses(); void registerTestPDLByteCodePass(); @@ -301,6 +302,7 @@ void registerTestPasses() { mlir::test::registerTestVectorReductionToSPIRVDotProd(); mlir::test::registerTestVulkanRunnerPipeline(); mlir::test::registerTestWrittenToPass(); + mlir::test::registerTestXeGPULowerings(); #if MLIR_ENABLE_PDL_IN_PATTERNMATCH mlir::test::registerTestDialectConversionPasses(); mlir::test::registerTestPDLByteCodePass();
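
Usage sketch for downstream users: the snippet below shows one way to drive the new unroll patterns from custom code. The fixed 8x16 native shape, the helper name `applyFixedShapeUnrolling`, and the restriction to load_nd/store_nd are illustrative assumptions rather than part of this patch; the `-test-xegpu-unrolling-patterns` pass above remains the in-tree reference, and a production UnrolledTypesFn would also adjust the TensorDesc layout as that pass does.

// Sketch only: drives the unroll patterns with a fixed 8x16 tile shape.
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static void applyFixedShapeUnrolling(Operation *root) {
  xegpu::UnrollOptions options;

  // Only attempt to unroll load_nd/store_nd; skip every other op.
  options.setFilterConstraint([](Operation *op) -> LogicalResult {
    return success(isa<xegpu::LoadNdOp, xegpu::StoreNdOp>(op));
  });

  // Use the same 8x16 native shape for every filtered op.
  options.setNativeShapeFn(
      [](Operation *op) -> std::optional<SmallVector<int64_t>> {
        return SmallVector<int64_t>{8, 16};
      });

  // Clone the shaped type with the tile shape, one copy per tile. A real
  // callback would also drop the inst_data layout field for TensorDescType,
  // as the test pass does.
  options.setUnrolledTypesFn(
      [](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
        std::optional<SmallVector<int64_t>> ratio =
            computeShapeRatio(type.getShape(), tileShape);
        assert(ratio && "expected the tile shape to divide the type's shape");
        Type newTy = type.clone(tileShape, type.getElementType());
        return SmallVector<Type>(computeProduct(*ratio), newTy);
      });

  RewritePatternSet patterns(root->getContext());
  xegpu::populateXeGPUUnrollPatterns(patterns, options);
  (void)applyPatternsGreedily(root, std::move(patterns));
}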