diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..6f585f9ceb29b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -6,7 +6,6 @@
 
 //
 //===----------------------------------------------------------------------===//
-
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 
@@ -18,9 +17,7 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
     The pass folds aliasing ops into XeGPU ops that they operate on the
     original source references.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect"
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect"];
 }
 
 def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
@@ -28,14 +25,24 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   let description = [{
     The pass distributes subgroup level (SIMD) XeGPU ops to work items.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
-  let options = [
-    Option<"printOnly", "print-analysis-only", "bool",
-           /*default=*/"false",
-           "Print the result of the subgroup map propagation analysis and exit.">
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+  let options = [Option<
+      "printOnly", "print-analysis-only", "bool",
+      /*default=*/"false",
+      "Print the result of the subgroup map propagation analysis and exit.">];
+}
+
+def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
+  let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
+  let description = [{
+    This transform pass distributes the workgroup level computation to
+    multiple subgroups based on the sg_layout and sg_data attributes.
+  }];
+
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect", "arith::ArithDialect",
+                           "gpu::GPUDialect", "index::IndexDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 559cc3ece62fb..44b81796b1313 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -62,6 +62,7 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
 /// Users can control whether an operation to be unrolled or not, as well as
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 892eb791c46e7..837303b04e9d7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
+  XeGPUWgToSgDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
new file mode 100644
index 0000000000000..3bf76af674ba0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -0,0 +1,378 @@
+//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin assignment to distribute the work to the subgroups.
+/// Following create_nd_tdesc operation:
+///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+///      -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout &
+/// sg_data:
+///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+///      !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
+///        lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data attributes are dropped after the pass as they are
+/// no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+///   +------------------------+
+///   |  8x8  |  8x8  |  8x8   |  <- 3 tiles across
+///   |-------+-------+--------|
+///   |  8x8  |  8x8  |  8x8   |  <- 3 tiles down
+///   |-------+-------+--------|
+///   |  8x8  |  8x8  |  8x8   |
+///   +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+///   +------------------------+
+///   | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
+///   | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
+///   | 2x2 2x2 2x2 2x2 |
+///   | 2x2 2x2 2x2 2x2 |
+///   +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
+/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
+
+/// The pass currently has the entire distribution logic in the
+/// WgToSgCreateNdOp pattern and all the other ops just follow.
+/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
+/// ops in the pass.
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  // Calculate offset for each subgroup
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<OpFoldResult> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr,
+                         const SmallVector<int64_t> &distUnitShape) const {
+    assert(localOffset.size() == distUnitBaseAddr.size() &&
+           "localOffset and distUnitBaseAddr must have the same rank");
+
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    size_t rank = localOffset.size();
+    for (size_t i = 0; i < rank; ++i) {
+      size_t dimIdx = originalOffsets.size() - rank + i;
+      Value constOffset =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+      Value offset =
+          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value modValue =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
+      Value offsetMod =
+          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
+      Value origOffset = getValueOrCreateConstantIndexOp(
+          rewriter, loc, originalOffsets[dimIdx]);
+      Value globalOffset =
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
+      globalOffsets[dimIdx] = globalOffset;
+    }
+
+    return globalOffsets;
+  }
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    // sgLayout must be present for workgroup-level distribution.
+    SmallVector<int64_t> sgLayout;
+    if (auto sgLayoutAttr = layout.getSgLayout())
+      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    else
+      return rewriter.notifyMatchFailure(
+          op, "sgLayout attribute is required in layout");
+
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      assert(wgShape.size() == sgLayout.size() &&
+             "sgLayout and wgShape must have the same rank");
+      sgShape.reserve(wgShape.size());
+      for (size_t i = 0; i < wgShape.size(); ++i) {
+        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(wgShape[i] / sgLayout[i]);
+      }
+    }
+
+    // TODO: Handle order attribute
+    // Get the subgroup ID
+    auto linearSgId =
+        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+    // Create constants for layout dimensions
+    SmallVector<Value> sgLayoutDim(sgLayout.size());
+    SmallVector<Value> sgDataDim(sgShape.size());
+
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      sgLayoutDim[i] =
+          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
+      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
+    }
+
+    auto deLinearizeSgId =
+        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+    if (failed(deLinearizeSgId))
+      return failure();
+    SmallVector<Value> sgIds = *deLinearizeSgId;
+
+    // Calculate distribution unit shape and local offsets for subgroup
+    SmallVector<int64_t> distUnitShape(sgLayout.size());
+    SmallVector<Value> localOffset(sgLayout.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
+      localOffset[i] =
+          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
+    }
+
+    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
+
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+    SmallVector<Value> newCreateNdOps;
+    for (SmallVector<int64_t> distUnitBaseAddr :
+         StaticTileOffsetRange(wgShape, distUnitShape)) {
+      SmallVector<OpFoldResult> globalOffsets =
+          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
+                                 distUnitBaseAddr, distUnitShape);
+
+      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newCreateNdOps.push_back(newCreateNdOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the LoadNdOp to load subgroup data.
+struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newLoadOps;
+    for (auto src : adaptor.getTensorDesc()) {
+      xegpu::TensorDescType tdescTy =
+          dyn_cast<xegpu::TensorDescType>(src.getType());
+      ArrayRef<int64_t> srcShape = tdescTy.getShape();
+      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
+                                                        src, op->getAttrs());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return mlir::success();
+  }
+};
+
+/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
+/// It creates a StoreNdOp to store the updated values to the new subgroup
+/// src tensor descriptors.
+struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
+/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
+/// offsets of the new subgroup src tensor descriptors.
+struct WgToSgUpdateNdOffsetOp
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> newUpdateTileOffsetOps;
+    for (auto tDesc : adaptor.getTensorDesc()) {
+      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
+          op.getConstOffsets());
+      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the DpasOp to work at subgroup level.
+struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto originalLayout =
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!originalLayout)
+      return failure();
+
+    SmallVector<Value> newDpasOps;
+    size_t i = 0;
+    for (auto aVec : adaptor.getLhs()) {
+      for (auto bVec : adaptor.getRhs()) {
+        llvm::SmallVector<Value> operands({aVec, bVec});
+        Value tmpC;
+        if (op.getAcc()) {
+          tmpC = adaptor.getAcc()[i++];
+          operands.push_back(tmpC);
+        }
+
+        ArrayRef<int64_t> aVecShape =
+            llvm::cast<VectorType>(aVec.getType()).getShape();
+        ArrayRef<int64_t> bVecShape =
+            llvm::cast<VectorType>(bVec.getType()).getShape();
+        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
+                                           resultTy.getElementType());
+        tmpC = rewriter.create<xegpu::DpasOp>(
+            loc, resTy, operands,
+            llvm::ArrayRef<NamedAttribute>(
+                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        newDpasOps.push_back(tmpC);
+      }
+    }
+    rewriter.replaceOpWithMultiple(op, {newDpasOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
+struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto src : adaptor.getTensorDesc())
+      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
+                                           op->getAttrs());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace xegpu {
+void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
+  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+      patterns.getContext());
+}
+} // namespace xegpu
+} // namespace mlir
+
+namespace {
+struct XeGPUWgToSgDistributePass
+    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUWgToSgDistributePass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  ConversionTarget target(*ctx);
+
+  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
+    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
+      return createOp.getType();
+    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
+      return loadOp.getTensorDescType();
+    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
+      return storeOp.getTensorDescType();
+    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+      return updateOp.getType();
+    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
+      return prefetchOp.getTensorDescType();
+    return xegpu::TensorDescType();
+  };
+
+  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
+    return !layout || layout.getSgLayout() == nullptr;
+  };
+
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
+                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
+    auto tdescTy = getTensorDescType(op);
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    return isLegal(layout);
+  });
+
+  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    return isLegal(layout);
+  });
+
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
new file mode 100644
index 0000000000000..bee026eb2084d
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -0,0 +1,105 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+gpu.module @test_round_robin_assignment {
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_load_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
+    // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
+    // CHECK-NOT: xegpu.load_nd
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
+    // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.store_nd
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc
+      : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_update_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_update_nd(%src: memref<24x32xf32>){
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
+    // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.update_nd_offset
+    %update = xegpu.update_nd_offset %tdesc, [0, 16]
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_dpas
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
+  gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
+    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    // CHECK-NOT: xegpu.dpas
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<8x8xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<8x8xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
+    // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
new file mode 100644
index 0000000000000..7e89ada934071
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -0,0 +1,172 @@
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+gpu.module @test_1_1_assignment {
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+    // CHECK: %[[C12:.*]] = arith.constant 12 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
+    // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
+    // CHECK: %[[C24:.*]] = arith.constant 24 : index
+    // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
+    // CHECK: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
+    // CHECK: %[[C32:.*]] = arith.constant 32 : index
+    // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
+    // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+    // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: gpu.return
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_load_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
+    // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc
+      : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_update_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_update_nd(%src: memref<24x32xf32>){
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %update = xegpu.update_nd_offset %tdesc, [0, 16]
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_dpas
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
+  gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
+    // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
+      -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+      -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_dpas_no_sg_data
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
+  gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
+    // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
+      -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b = xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+      -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: xegpu.prefetch_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: test_dpas_with_no_create_nd_desc
+  gpu.func @test_dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) {
+    // CHECK-NOT: vector<12x12xf32>
+    %dpas = xegpu.dpas %a, %b
+      {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+}
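
The number of subgroup-level create_nd_tdesc ops that WgToSgCreateNdOp emits follows directly from three pieces of the pattern above: the sgShape default (wgShape[i] / sgLayout[i] when sg_data is absent), the distribution-unit shape (min(sgLayout[i] * sgShape[i], wgShape[i])), and the StaticTileOffsetRange walk over wgShape in distribution-unit steps. The standalone C++ sketch below is not part of the patch; countDistUnits is an illustrative helper, and it assumes wgShape divides evenly by the distribution-unit shape, as in the tests above. It only restates that arithmetic for the 24x24 example from the pass documentation and the shapes used in the two new test files.

// Illustrative sketch only -- mirrors the shape arithmetic of WgToSgCreateNdOp
// under the even-divisibility assumption used by the tests in this patch.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Number of subgroup-level create_nd_tdesc ops emitted per workgroup-level op:
// one per distribution unit visited by StaticTileOffsetRange.
static int64_t countDistUnits(std::vector<int64_t> wgShape,
                              std::vector<int64_t> sgLayout,
                              std::vector<int64_t> sgData) {
  int64_t numUnits = 1;
  for (size_t i = 0; i < wgShape.size(); ++i) {
    // sg_data defaults to wgShape[i] / sgLayout[i] when it is not provided.
    int64_t data = sgData.empty() ? wgShape[i] / sgLayout[i] : sgData[i];
    // dist_unit_shape[i] = min(sg_layout[i] * sg_data[i], wg_shape[i]).
    int64_t distUnit = std::min(sgLayout[i] * data, wgShape[i]);
    numUnits *= wgShape[i] / distUnit;
  }
  return numUnits;
}

int main() {
  // 24x24 example from the pass documentation: 3x3 = 9 subgroup ops.
  assert(countDistUnits({24, 24}, {4, 4}, {2, 2}) == 9);
  // 24x32 round-robin case (xegpu-wg-to-sg-rr.mlir): 3x4 = 12 subgroup ops,
  // matching the CHECK-COUNT-12 lines.
  assert(countDistUnits({24, 32}, {4, 4}, {2, 2}) == 12);
  // 24x32 case in xegpu-wg-to-sg.mlir: a single distribution unit (1:1).
  assert(countDistUnits({24, 32}, {2, 4}, {12, 8}) == 1);
  // Same shape with sg_data omitted: the default derivation gives [12, 8].
  assert(countDistUnits({24, 32}, {2, 4}, {}) == 1);
  std::printf("all distribution-unit counts match\n");
  return 0;
}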