diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
index 7e106c90ada0..b525ae238305 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -86,6 +86,7 @@ iree_compiler_cc_library(
         "ConvertAccGEMMToGEMMPass.cpp",
         "ConvertBf16ArithToF32.cpp",
         "ConvertBf16ToUInt16Buffers.cpp",
+        "ConvertForallToGenericNestWorkgroup.cpp",
         "ConvertToDestinationPassingStylePass.cpp",
         "ConvertUnsupportedFloatArithPass.cpp",
         "ConvertWorkgroupForallToPCF.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
index 48497a85be7f..d741f93d98c4 100644
--- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -79,6 +79,7 @@ iree_cc_library(
     "ConvertAccGEMMToGEMMPass.cpp"
     "ConvertBf16ArithToF32.cpp"
     "ConvertBf16ToUInt16Buffers.cpp"
+    "ConvertForallToGenericNestWorkgroup.cpp"
    "ConvertToDestinationPassingStylePass.cpp"
     "ConvertUnsupportedFloatArithPass.cpp"
     "ConvertWorkgroupForallToPCF.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertForallToGenericNestWorkgroup.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertForallToGenericNestWorkgroup.cpp
new file mode 100644
index 000000000000..e52a39d3778b
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertForallToGenericNestWorkgroup.cpp
@@ -0,0 +1,70 @@
+// Copyright 2026 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_CONVERTFORALLTOGENERICNESTWORKGROUPPASS
+#include "iree/compiler/Codegen/Common/Passes.h.inc"
+
+namespace {
+
+/// Returns true if the forall op has WorkgroupMappingAttr mapping attributes.
+static bool hasWorkgroupMapping(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mapping = forallOp.getMapping();
+  if (!mapping || mapping->empty()) {
+    return false;
+  }
+  return llvm::all_of(mapping.value(),
+                      llvm::IsaPred<IREE::Codegen::WorkgroupMappingAttr>);
+}
+
+struct ConvertForallToGenericNestWorkgroupPass final
+    : public impl::ConvertForallToGenericNestWorkgroupPassBase<
+          ConvertForallToGenericNestWorkgroupPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Always use linearized workgroup scope (1 id).
+    // The interface is implemented via external models, hence the cast.
+    auto scope = cast<IREE::PCF::ScopeAttrInterface>(
+        IREE::Codegen::WorkgroupScopeAttr::get(ctx, /*linearize=*/true));
+
+    SmallVector<IREE::PCF::ScopeAttrInterface> scopes = {scope};
+
+    IRRewriter rewriter(ctx);
+    SmallVector<scf::ForallOp> forallOps;
+    getOperation()->walk([&](scf::ForallOp forallOp) {
+      // Only convert foralls with workgroup mapping attributes.
+      if (hasWorkgroupMapping(forallOp)) {
+        forallOps.push_back(forallOp);
+      }
+    });
+
+    for (scf::ForallOp forallOp : forallOps) {
+      rewriter.setInsertionPoint(forallOp);
+      FailureOr<IREE::PCF::GenericOp> result =
+          IREE::PCF::convertForallToGenericNest(rewriter, forallOp, scopes);
+      if (failed(result)) {
+        forallOp.emitError("failed to convert forall to generic nest");
+        return signalPassFailure();
+      }
+      // Replace forall results with generic results.
+      rewriter.replaceOp(forallOp, result->getResults());
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
index aeceb0367b6b..ff268af52998 100644
--- a/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp
@@ -53,7 +53,8 @@ ConvertWorkgroupForall::matchAndRewrite(scf::ForallOp op,
   auto scope = cast<IREE::PCF::ScopeAttrInterface>(
       IREE::Codegen::WorkgroupScopeAttr::get(rewriter.getContext(),
                                              /*linearize=*/true));
-  FailureOr<PCF::LoopOp> res = convertForallToPCF(rewriter, op, scope, 1);
+  FailureOr<PCF::LoopOp> res =
+      convertForallToPCFLoop(rewriter, op, scope, 1);
   if (failed(res)) {
     return failure();
   }
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h
index ec4b0819ae9a..cd2b5c62005c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.h
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -18,6 +18,7 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index 3acbde36e3ec..f3bd759fd792 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -155,6 +155,23 @@ def ConvertWorkgroupForallToPCFPass
   let dependentDialects = ["iree_compiler::IREE::PCF::PCFDialect"];
 }
 
+def ConvertForallToGenericNestWorkgroupPass
+    : InterfacePass<"iree-codegen-convert-forall-to-generic-nest-workgroup",
+                    "mlir::FunctionOpInterface"> {
+  let summary = "Converts scf.forall ops with workgroup mapping to pcf.generic";
+  let description = [{
+    Converts `scf.forall` ops with `#iree_codegen.workgroup_mapping` attributes
+    to a `pcf.generic` op using workgroup scope. The pass always linearizes
+    workgroup IDs to a single dimension.
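+
+    A schematic example (IR details elided; the exact form follows this
+    pass's lit test). An input like
+
+    ```mlir
+    %0 = scf.forall (%i) in (64) shared_outs(%o = %init) -> tensor<64xf32> {
+      ...
+    } {mapping = [#iree_codegen.workgroup_mapping<x>]}
+    ```
+
+    becomes a `pcf.generic` over workgroup scope whose body computes a
+    contiguous chunk of the iteration space from its id/count pair
+    (`chunk = ceildiv(64, count)`, covering `[id * chunk, min(id * chunk +
+    chunk, 64))`) and visits it with an inner `scf.forall` that writes
+    results through `pcf.write_slice`.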
+ }]; + let dependentDialects = [ + "iree_compiler::IREE::PCF::PCFDialect", + "arith::ArithDialect", + "affine::AffineDialect", + "scf::SCFDialect" + ]; +} + def CombineLayoutTransformationPass : InterfacePass<"iree-codegen-combine-layout-transformation", "mlir::FunctionOpInterface"> { let summary = diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index af3c70e2ef28..6d33e2ddef12 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -33,6 +33,7 @@ iree_lit_test_suite( "convert_accgemm_to_gemm.mlir", "convert_bf16_arith_to_f32.mlir", "convert_bf16_to_uint16_buffers.mlir", + "convert_forall_to_generic_nest_workgroup.mlir", "convert_hal_descriptor_type_to_gpu_address_space.mlir", "convert_to_destination_passing_style.mlir", "convert_unsupported_float_arith.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 0295c5c09fb0..89bf7f381800 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -28,6 +28,7 @@ iree_lit_test_suite( "convert_accgemm_to_gemm.mlir" "convert_bf16_arith_to_f32.mlir" "convert_bf16_to_uint16_buffers.mlir" + "convert_forall_to_generic_nest_workgroup.mlir" "convert_hal_descriptor_type_to_gpu_address_space.mlir" "convert_to_destination_passing_style.mlir" "convert_unsupported_float_arith.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convert_forall_to_generic_nest_workgroup.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convert_forall_to_generic_nest_workgroup.mlir new file mode 100644 index 000000000000..3a71f13f6d67 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/convert_forall_to_generic_nest_workgroup.mlir @@ -0,0 +1,29 @@ +// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-codegen-convert-forall-to-generic-nest-workgroup))" --allow-unregistered-dialect --split-input-file | FileCheck %s + +// Test that workgroup scope creates 1 id/count pair with linearized scope. + +// CHECK-LABEL: func.func @test_workgroup_scope +// CHECK: pcf.generic +// CHECK-SAME: scope(#iree_codegen.workgroup_scope) +// CHECK: execute(%{{.*}})[%[[ID:.+]]: index, %[[COUNT:.+]]: index] +// Chunk size computed from total iterations / worker count. +// CHECK: %[[CHUNK:.+]] = arith.ceildivui +// Start = id * chunk_size. +// CHECK: %[[START:.+]] = arith.muli %[[ID]], %[[CHUNK]] +// End = min(start + chunk_size, total). 
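+// For example, with 64 iterations and a worker count of 8, the chunk size is
+// 8 and worker 3 covers [24, 32).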
+// CHECK: %[[END_RAW:.+]] = arith.addi %[[START]], %[[CHUNK]]
+// CHECK: %[[END:.+]] = arith.minui %[[END_RAW]]
+// CHECK: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
+// CHECK: "foo.body"(%[[IV]])
+// CHECK: pcf.write_slice
+// CHECK: pcf.return
+func.func @test_workgroup_scope(%init: tensor<64xf32>) -> tensor<64xf32> {
+  %result = scf.forall (%i) in (64) shared_outs(%out = %init) -> tensor<64xf32> {
+    "foo.body"(%i) : (index) -> ()
+    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<64xf32> to tensor<1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i] [1] [1] : tensor<1xf32> into tensor<64xf32>
+    }
+  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
+  return %result : tensor<64xf32>
+}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/GPUScopeExternalModels.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/GPUScopeExternalModels.cpp
index d31e698f2e6a..4b5778cefdfd 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/GPUScopeExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/GPUScopeExternalModels.cpp
@@ -83,6 +83,11 @@ struct SubgroupScopeModel
                                  MLIRContext *context) const {
     return gpu::AddressSpaceAttr::get(context, gpu::AddressSpace::Workgroup);
   }
+
+  int64_t getNativeNumProcessorIds(Attribute attr) const {
+    // SubgroupScope natively provides a single 1D processor ID (subgroup_id).
+    return 1;
+  }
 };
 
 /// External model for LaneScopeAttr implementing ScopeAttrInterface.
@@ -133,6 +138,11 @@ struct LaneScopeModel
     // logic to allocate + subview.
     return failure();
   }
+
+  int64_t getNativeNumProcessorIds(Attribute attr) const {
+    // LaneScope natively provides a single 1D processor ID (lane_id).
+ return 1; + } }; } // namespace diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel index ab68bebc5f4d..34e3ef937271 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/BUILD.bazel @@ -60,6 +60,7 @@ iree_compiler_cc_library( name = "GPUTransforms", srcs = [ "CombineBarrierRegions.cpp", + "ConvertForallToGenericNestGPU.cpp", "DistributeInnerTiledToLanes.cpp", "ExpandUndistributedInnerTiles.cpp", "LowerIREEGPUOps.cpp", @@ -77,6 +78,8 @@ iree_compiler_cc_library( ":PassesIncGen", "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect", "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect", + "//compiler/src/iree/compiler/Codegen/Dialect/PCF/IR", + "//compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms", "//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect", "//compiler/src/iree/compiler/Codegen/Transforms", "//compiler/src/iree/compiler/Codegen/Utils", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt index 7fc891bb26bf..996a24a4b1c1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/CMakeLists.txt @@ -49,6 +49,7 @@ iree_cc_library( "Transforms.h" SRCS "CombineBarrierRegions.cpp" + "ConvertForallToGenericNestGPU.cpp" "DistributeInnerTiledToLanes.cpp" "ExpandUndistributedInnerTiles.cpp" "LowerIREEGPUOps.cpp" @@ -84,6 +85,8 @@ iree_cc_library( MLIRVectorUtils iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect + iree::compiler::Codegen::Dialect::PCF::IR + iree::compiler::Codegen::Dialect::PCF::Transforms iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect iree::compiler::Codegen::Transforms iree::compiler::Codegen::Utils diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConvertForallToGenericNestGPU.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConvertForallToGenericNestGPU.cpp new file mode 100644 index 000000000000..beed5859c830 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConvertForallToGenericNestGPU.cpp @@ -0,0 +1,84 @@ +// Copyright 2026 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h" +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h" +#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h" +#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir::iree_compiler::IREE::GPU { + +#define GEN_PASS_DEF_CONVERTFORALLTOGENERICNESTGPUPASS +#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc" + +namespace { + +/// Returns true if the forall op has gpu::GPUThreadMappingAttr mapping +/// attributes. 
+static bool hasThreadMapping(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mapping = forallOp.getMapping();
+  if (!mapping || mapping->empty()) {
+    return false;
+  }
+  return llvm::all_of(mapping.value(),
+                      llvm::IsaPred<gpu::GPUThreadMappingAttr>);
+}
+
+/// Runs the conversion for forall ops with gpu.thread mapping to nested
+/// pcf.generic ops.
+static LogicalResult runConversion(Operation *op) {
+  MLIRContext *ctx = op->getContext();
+
+  // Create nested scopes: subgroup (outer) + lane (inner).
+  // The interface is implemented via external models, hence the casts.
+  auto subgroupScope =
+      cast<PCF::ScopeAttrInterface>(SubgroupScopeAttr::get(ctx));
+  auto laneScope = cast<PCF::ScopeAttrInterface>(LaneScopeAttr::get(ctx));
+
+  SmallVector<PCF::ScopeAttrInterface> scopes = {subgroupScope, laneScope};
+
+  IRRewriter rewriter(ctx);
+  SmallVector<scf::ForallOp> forallOps;
+  op->walk([&](scf::ForallOp forallOp) {
+    // Only convert foralls with gpu.thread mapping attributes.
+    if (hasThreadMapping(forallOp)) {
+      forallOps.push_back(forallOp);
+    }
+  });
+
+  for (scf::ForallOp forallOp : forallOps) {
+    rewriter.setInsertionPoint(forallOp);
+    FailureOr<PCF::GenericOp> result =
+        PCF::convertForallToGenericNest(rewriter, forallOp, scopes);
+    if (failed(result)) {
+      forallOp.emitError("failed to convert forall to generic nest");
+      return failure();
+    }
+    // Replace forall results with generic results.
+    rewriter.replaceOp(forallOp, result->getResults());
+  }
+  return success();
+}
+
+struct ConvertForallToGenericNestGPUPass final
+    : public impl::ConvertForallToGenericNestGPUPassBase<
+          ConvertForallToGenericNestGPUPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    if (failed(runConversion(getOperation()))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+} // namespace
+} // namespace mlir::iree_compiler::IREE::GPU
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
index 774dc1754142..ec7b02bffc0d 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.td
@@ -71,4 +71,21 @@ def VectorizeIREEGPUOpsPass :
   ];
 }
 
+def ConvertForallToGenericNestGPUPass :
+    InterfacePass<"iree-gpu-convert-forall-to-generic-nest",
+                  "mlir::FunctionOpInterface"> {
+  let summary = "Converts scf.forall ops with GPU mapping to pcf.generic";
+  let description = [{
+    Converts scf.forall ops with gpu.thread mapping to nested pcf.generic ops
+    using subgroup (outer) and lane (inner) scopes.
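+
+    A schematic example of the resulting nest (IR details elided; the exact
+    form follows this pass's lit test):
+
+    ```mlir
+    pcf.generic scope(#iree_gpu.subgroup_scope) ...
+      execute(%ref = %init)[%sg_id: index, %num_sgs: index] {
+        pcf.generic scope(#iree_gpu.lane_scope) ...
+          execute[%lane_id: index, %sg_size: index] {
+            %id = affine.linearize_index [%sg_id, %lane_id]
+                    by (%num_sgs, %sg_size) : index
+            // %id picks a contiguous chunk of iterations, which an inner
+            // scf.forall visits, writing through pcf.write_slice.
+            ...
+            pcf.return
+          }
+        pcf.return
+      }
+    ```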
+ }]; + let dependentDialects = [ + "::mlir::iree_compiler::IREE::PCF::PCFDialect", + "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect", + "::mlir::arith::ArithDialect", + "::mlir::affine::AffineDialect", + "::mlir::scf::SCFDialect", + ]; +} + #endif // IREE_CODEGEN_DIALECt_GPU_TRANSFORMS_PASSES diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel index 46574c8f7180..bf47b581da89 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/BUILD.bazel @@ -20,6 +20,7 @@ iree_lit_test_suite( # keep sorted [ "combine_barrier_regions.mlir", + "convert_forall_to_generic_nest_gpu.mlir", "distribute_inner_tiled_to_lanes.mlir", "expand_undistributed_inner_tiles.mlir", ], diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt index 432df0ee0e15..14ebdeef859c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "combine_barrier_regions.mlir" + "convert_forall_to_generic_nest_gpu.mlir" "distribute_inner_tiled_to_lanes.mlir" "expand_undistributed_inner_tiles.mlir" TOOLS diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/convert_forall_to_generic_nest_gpu.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/convert_forall_to_generic_nest_gpu.mlir new file mode 100644 index 000000000000..b4cc529ed7c4 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/convert_forall_to_generic_nest_gpu.mlir @@ -0,0 +1,62 @@ +// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-gpu-convert-forall-to-generic-nest))" --split-input-file | FileCheck %s + +// Test converting scf.forall with gpu.thread mapping to nested pcf.generic +// with subgroup scope (outer) and lane scope (inner). 
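+//
+// Work is distributed by linearizing the (subgroup, lane) id pair and giving
+// each worker a contiguous chunk of ceildiv(total, num_workers) iterations.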
+func.func @test_1d_thread_mapping(%init: tensor<64xf32>) -> tensor<64xf32> {
+  %result = scf.forall (%i) in (64) shared_outs(%out = %init) -> tensor<64xf32> {
+    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<64xf32> to tensor<1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i] [1] [1] : tensor<1xf32> into tensor<64xf32>
+    }
+  } {mapping = [#gpu.thread<x>]}
+  return %result : tensor<64xf32>
+}
+
+// CHECK-LABEL: func.func @test_1d_thread_mapping
+// CHECK-SAME: %[[INIT:.+]]: tensor<64xf32>
+// CHECK: %[[RESULT:.+]] = pcf.generic
+// CHECK-SAME: scope(#iree_gpu.subgroup_scope)
+// CHECK: execute(%[[REF:.+]] = %[[INIT]])[%[[SUBGROUP_ID:.+]]: index, %[[NUM_SUBGROUPS:.+]]: index]
+// CHECK: pcf.generic
+// CHECK-SAME: scope(#iree_gpu.lane_scope)
+// CHECK: execute[%[[LANE_ID:.+]]: index, %[[SUBGROUP_SIZE:.+]]: index]
+// CHECK: %[[LIN_ID:.+]] = affine.linearize_index [%[[SUBGROUP_ID]], %[[LANE_ID]]] by (%[[NUM_SUBGROUPS]], %[[SUBGROUP_SIZE]])
+// CHECK: %[[TOTAL_COUNT:.+]] = arith.muli %[[NUM_SUBGROUPS]], %[[SUBGROUP_SIZE]]
+// CHECK: %[[TILE_SIZE:.+]] = arith.ceildivui %{{.+}}, %[[TOTAL_COUNT]]
+// CHECK: %[[START:.+]] = arith.muli %[[LIN_ID]], %[[TILE_SIZE]]
+// CHECK: %[[END_UNCLAMPED:.+]] = arith.addi %[[START]], %[[TILE_SIZE]]
+// CHECK: %[[END:.+]] = arith.minui %[[END_UNCLAMPED]]
+// CHECK: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
+// CHECK: pcf.write_slice %{{.+}} into %[[REF]][%[[IV]]]
+// CHECK: pcf.return
+// CHECK: pcf.return
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @test_2d_thread_mapping(%init: tensor<64x128xf32>) -> tensor<64x128xf32> {
+  %result = scf.forall (%i, %j) in (64, 128) shared_outs(%out = %init) -> tensor<64x128xf32> {
+    %slice = tensor.extract_slice %out[%i, %j] [1, 1] [1, 1] : tensor<64x128xf32> to tensor<1x1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i, %j] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<64x128xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  return %result : tensor<64x128xf32>
+}
+
+// CHECK-LABEL: func.func @test_2d_thread_mapping
+// CHECK-SAME: %[[INIT:.+]]: tensor<64x128xf32>
+// CHECK: pcf.generic
+// CHECK-SAME: scope(#iree_gpu.subgroup_scope)
+// CHECK: execute(%[[REF:.+]] = %[[INIT]])[%[[SUBGROUP_ID:.+]]: index, %[[NUM_SUBGROUPS:.+]]: index]
+// CHECK: pcf.generic
+// CHECK-SAME: scope(#iree_gpu.lane_scope)
+// CHECK: execute[%[[LANE_ID:.+]]: index, %[[SUBGROUP_SIZE:.+]]: index]
+// CHECK: %[[LIN_ID:.+]] = affine.linearize_index [%[[SUBGROUP_ID]], %[[LANE_ID]]] by (%[[NUM_SUBGROUPS]], %[[SUBGROUP_SIZE]])
+// CHECK: %[[TOTAL_COUNT:.+]] = arith.muli %[[NUM_SUBGROUPS]], %[[SUBGROUP_SIZE]]
+// CHECK: scf.forall (%[[IV:.+]]) =
+// CHECK: %[[INDICES:.+]]:2 = affine.delinearize_index %[[IV]] into (64, 128)
+// CHECK: pcf.write_slice %{{.+}} into %[[REF]][%[[INDICES]]#0, %[[INDICES]]#1]
+// CHECK: pcf.return
+// CHECK: pcf.return
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFBase.td b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFBase.td
index 245030681a4a..f3bcdf038b09 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFBase.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFBase.td
@@ -207,6 +207,10 @@ def PCF_TestScopeAttr
       assert(false && "why are you here?");
       return {};
     }
+    int64_t getNativeNumProcessorIds() {
+      assert(false && "why are you here?");
+      return 1;
+    }
   }];
 }
 
@@ -234,6 +238,9 @@ def PCF_SequentialAttr
       SmallVector<Value> ids(numIds, zero);
       return ids;
     }
+    int64_t
getNativeNumProcessorIds() { + return 1; + } }]; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFInterfaces.td b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFInterfaces.td index ba91e1aa0dfe..8d5fa8dcb2c5 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFInterfaces.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/IR/PCFInterfaces.td @@ -115,6 +115,20 @@ def PCF_ScopeAttrInterface : AttrInterface<"ScopeAttrInterface"> { return b.getI64IntegerAttr(16); }] >, + InterfaceMethod< + /*desc=*/[{ + Returns the number of physical processor IDs this scope natively + supports. Iterating over fewer than the native number of IDs results + in potentially duplicated execution over the unaccounted for + dimension(s). + + For example, a workgroup scope typically returns 3 (x, y, z), while + a sequential scope returns 1. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getNativeNumProcessorIds", + /*args=*/(ins) + >, ]; } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel index f307c8bf4d97..ed2a5704c496 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/BUILD.bazel @@ -35,7 +35,7 @@ iree_gentbl_cc_library( iree_compiler_cc_library( name = "Transforms", srcs = [ - "ConvertForallToLoops.cpp", + "ConvertForallToPCF.cpp", "ConvertSRefToMemRef.cpp", "FuseConsumers.cpp", "FusePCFWrites.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt index 611acc14b49a..d7c798b2fe41 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/CMakeLists.txt @@ -28,7 +28,7 @@ iree_cc_library( "Passes.h.inc" "Transforms.h" SRCS - "ConvertForallToLoops.cpp" + "ConvertForallToPCF.cpp" "ConvertSRefToMemRef.cpp" "FuseConsumers.cpp" "FusePCFWrites.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertForallToLoops.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertForallToLoops.cpp deleted file mode 100644 index b260f3efdac5..000000000000 --- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertForallToLoops.cpp +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright 2025 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
-#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
-#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFAttrs.h"
-#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFOps.h"
-#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.h"
-#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h"
-#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
-#include "iree/compiler/Utils/RewriteUtils.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallVectorExtras.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
-#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-
-#define DEBUG_TYPE "iree-pcf-convert-forall-to-loops"
-
-namespace mlir::iree_compiler::IREE::PCF {
-
-#define GEN_PASS_DEF_CONVERTFORALLTOLOOPSPASS
-#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h.inc"
-
-namespace {
-
-/// Returns true if the forall op has LocalMappingAttr mapping attributes,
-/// or the mapping is empty/not present.
-static bool hasEmptyOrLocalMapping(scf::ForallOp forallOp) {
-  std::optional<ArrayAttr> mapping = forallOp.getMapping();
-  if (!mapping || mapping->empty()) {
-    return true;
-  }
-  return llvm::all_of(mapping.value(),
-                      llvm::IsaPred<IREE::Codegen::LocalMappingAttr>);
-}
-
-struct ConvertForallToLoopsPass final
-    : impl::ConvertForallToLoopsPassBase<ConvertForallToLoopsPass> {
-  void runOnOperation() override;
-};
-
-void ConvertForallToLoopsPass::runOnOperation() {
-  SmallVector<scf::ForallOp> opsToConvert;
-  getOperation()->walk([&](scf::ForallOp forallOp) {
-    // Empty mapping, no mapping, and local mapping all map to `pcf.sequential`.
-    // If it is a local mapping, then the lowering pattern will automatically
-    // handle any mapping permutation based on the mapping attribute's relative
-    // id.
-    if (hasEmptyOrLocalMapping(forallOp)) {
-      opsToConvert.push_back(forallOp);
-    }
-  });
-
-  IRRewriter rewriter(getOperation());
-  PCF::ScopeAttrInterface sequentialScope =
-      PCF::SequentialAttr::get(&getContext());
-  for (auto forallOp : opsToConvert) {
-    rewriter.setInsertionPoint(forallOp);
-    if (failed(convertForallToPCF(rewriter, forallOp, sequentialScope))) {
-      forallOp->emitOpError("failed to convert forall");
-      return signalPassFailure();
-    }
-  }
-}
-
-static SmallVector<int64_t> getMappingPermutation(ArrayAttr mapping) {
-  auto mappingRange = mapping.getAsRange<DeviceMappingAttrInterface>();
-  int64_t mappingBase =
-      cast<DeviceMappingAttrInterface>(
-          *std::min_element(mappingRange.begin(), mappingRange.end(),
-                            [](auto a, auto b) {
-                              return a.getMappingId() < b.getMappingId();
-                            }))
-          .getMappingId();
-  return llvm::map_to_vector(
-      mappingRange, [&](auto a) { return a.getMappingId() - mappingBase; });
-}
-
-static FailureOr<SmallVector<int64_t>>
-getProcessorIdPermutation(scf::ForallOp forallOp) {
-  std::optional<ArrayAttr> mappingAttr = forallOp.getMapping();
-  // Unspecified mappings indicate sequential foralls which we can choose the
-  // iteration order for.
-  if (!mappingAttr) {
-    return llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(forallOp.getRank())));
-  }
-  // Empty mappings are unsupported at the moment. It's unclear when a forall
-  // with an empty mapping would be useful or important.
-  if (mappingAttr.value().empty()) {
-    return {};
-  }
-
-  SmallVector<int64_t> perm = getMappingPermutation(mappingAttr.value());
-  if (!isPermutationVector(perm)) {
-    return failure();
-  }
-  return perm;
-}
-
-static LogicalResult matchForallConversion(scf::ForallOp forallOp) {
-  scf::InParallelOp terminator = forallOp.getTerminator();
-  for (Operation &op : terminator.getBody()->getOperations()) {
-    // Bail on terminator ops other than parallel insert slice since we don't
-    // know how to convert it.
-    auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
-    if (!insertSliceOp) {
-      return failure();
-    }
-
-    // Bail on non-shared outs destinations.
-    auto bbArgDest = dyn_cast<BlockArgument>(insertSliceOp.getDest());
-    if (!bbArgDest || bbArgDest.getOwner()->getParentOp() != forallOp) {
-      return failure();
-    }
-  }
-
-  for (BlockArgument bbArg : forallOp.getRegionIterArgs()) {
-    for (OpOperand &use : bbArg.getUses()) {
-      // Skip users outside of the terminator. These are replaced with the init.
-      if (use.getOwner()->getParentOp() != terminator) {
-        continue;
-      }
-
-      // Bail if the use is not on the dest of the insert slice.
-      auto insertSliceUser =
-          cast<tensor::ParallelInsertSliceOp>(use.getOwner());
-      if (use != insertSliceUser.getDestMutable()) {
-        return failure();
-      }
-    }
-  }
-  if (failed(getProcessorIdPermutation(forallOp))) {
-    return failure();
-  }
-  return success();
-}
-
-static PCF::LoopOp convertForallToPCFImpl(RewriterBase &rewriter,
-                                          scf::ForallOp forallOp,
-                                          PCF::ScopeAttrInterface scope,
-                                          int64_t numIds) {
-  assert(succeeded(matchForallConversion(forallOp)) &&
-         "converting unsupported forall op");
-
-  // Maps from fastest -> slowest to current order.
-  SmallVector<int64_t> perm = *getProcessorIdPermutation(forallOp);
-
-  // Maps from current order to fastest -> slowest.
-  SmallVector<int64_t> invPerm = invertPermutationVector(perm);
-
-  // Get the permuted ubs/lbs/steps and save them for later since we need them
-  // to reconstruct the correct ids.
-  SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
-  applyPermutationToVector(mixedUbs, invPerm);
-  SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
-  applyPermutationToVector(mixedLbs, invPerm);
-  SmallVector<OpFoldResult> mixedStep = forallOp.getMixedStep();
-  applyPermutationToVector(mixedStep, invPerm);
-  // Permute the ivs of the body to match the original mapping order by
-  // permuting the uses before we move it over to the new op.
-  permuteValues(rewriter, forallOp.getLoc(), forallOp.getInductionVars(), perm);
-
-  scf::InParallelOp terminator = forallOp.getTerminator();
-  MutableArrayRef<BlockArgument> bodySharedOuts = forallOp.getRegionIterArgs();
-
-  // Replace non-insert slice users outside of the `scf.forall.in_parallel` with
-  // the init values.
-  ValueRange inits = forallOp.getDpsInits();
-  for (auto [init, bbArg] : llvm::zip_equal(inits, bodySharedOuts)) {
-    rewriter.replaceUsesWithIf(bbArg, init, [&](OpOperand &use) {
-      return use.getOwner()->getParentOp() != terminator;
-    });
-  }
-
-  Location loc = forallOp.getLoc();
-
-  AffineExpr s0, s1, s2;
-  bindSymbols(rewriter.getContext(), s0, s1, s2);
-  AffineExpr numIters = (s0 - s1).ceilDiv(s2);
-  SmallVector<Value> iterationCounts;
-  for (auto [ub, lb, step] : llvm::zip_equal(mixedUbs, mixedLbs, mixedStep)) {
-    OpFoldResult iterCount = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, numIters, ArrayRef<OpFoldResult>{ub, lb, step});
-    iterationCounts.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, iterCount));
-  }
-
-  int64_t numDelinIds = numIds > 0 && iterationCounts.size() > numIds
-                            ? iterationCounts.size() - numIds + 1
-                            : 0;
-  // Reverse the delinearization basis because affine.delinearize_index is from
-  // slowest to fastest varying.
-  SmallVector<Value> delinearizationBasis(
-      llvm::reverse(ArrayRef(iterationCounts).take_back(numDelinIds)));
-  if (!delinearizationBasis.empty()) {
-    AffineExpr mul = s0 * s1;
-    Value total = delinearizationBasis.front();
-    total = std::accumulate(
-        delinearizationBasis.begin() + 1, delinearizationBasis.end(), total,
-        [&](Value l, Value r) {
-          OpFoldResult acc =
-              affine::makeComposedFoldedAffineApply(rewriter, loc, mul, {l, r});
-          return getValueOrCreateConstantIndexOp(rewriter, loc, acc);
-        });
-    // Replace the first |numDelinIds| entries with their product.
-    iterationCounts.erase(iterationCounts.end() - numDelinIds,
-                          iterationCounts.end());
-    iterationCounts.push_back(total);
-  }
-
-  auto loopOp = PCF::LoopOp::create(rewriter, loc, scope, iterationCounts,
-                                    forallOp.getDpsInits());
-  SmallVector<Value> argReplacements;
-  {
-    OpBuilder::InsertionGuard g(rewriter);
-    rewriter.setInsertionPointToStart(loopOp.getBody());
-    // If/else here to avoid popping off the front of a vector.
-    if (!delinearizationBasis.empty()) {
-      auto delin = affine::AffineDelinearizeIndexOp::create(
-          rewriter, loc, loopOp.getIdArgs().back(), delinearizationBasis,
-          /*hasOuterBound=*/true);
-      argReplacements.append(loopOp.getIdArgs().begin(),
-                             loopOp.getIdArgs().end() - 1);
-      // Add replacements in reverse to get fastest -> slowest.
-      auto delinResultsReverse = llvm::reverse(delin.getResults());
-      argReplacements.append(delinResultsReverse.begin(),
-                             delinResultsReverse.end());
-    } else {
-      argReplacements.append(loopOp.getIdArgs().begin(),
-                             loopOp.getIdArgs().end());
-    }
-
-    // id * step + lb.
-    AffineExpr applyLbAndStep = s0 * s1 + s2;
-    for (auto [id, step, lb] :
-         llvm::zip_equal(argReplacements, mixedStep, mixedLbs)) {
-      OpFoldResult newId = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, applyLbAndStep, {id, step, lb});
-      id = getValueOrCreateConstantIndexOp(rewriter, loc, newId);
-    }
-  }
-
-  // Add parent only sync scope to the body arg types.
-  Attribute syncScope = PCF::SyncOnReturnAttr::get(rewriter.getContext());
-  for (auto regionRefArg : loopOp.getRegionRefArgs()) {
-    auto srefType = cast<PCF::ShapedRefType>(regionRefArg.getType());
-    auto newSrefType = PCF::ShapedRefType::get(
-        rewriter.getContext(), srefType.getShape(), srefType.getElementType(),
-        srefType.getScope(), syncScope);
-    regionRefArg.setType(newSrefType);
-  }
-
-  rewriter.setInsertionPoint(terminator);
-  llvm::SmallDenseMap<Value, Value> argToReplacementMap;
-  for (auto [bbArg, refArg] :
-       llvm::zip_equal(bodySharedOuts, loopOp.getRegionRefArgs())) {
-    argToReplacementMap[bbArg] = refArg;
-  }
-
-  // Iterate the insert_slice ops in the order to retain the order of writes.
-  SmallVector<tensor::ParallelInsertSliceOp> insertOps(
-      terminator.getBody()->getOps<tensor::ParallelInsertSliceOp>());
-  for (tensor::ParallelInsertSliceOp insertSliceOp : insertOps) {
-    PCF::WriteSliceOp::create(
-        rewriter, insertSliceOp.getLoc(), insertSliceOp.getSource(),
-        argToReplacementMap[insertSliceOp.getDest()],
-        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-        insertSliceOp.getMixedStrides());
-    rewriter.eraseOp(insertSliceOp);
-  }
-
-  // Replace the terminator with the new terminator kind.
-  rewriter.replaceOpWithNewOp<PCF::ReturnOp>(terminator);
-
-  // Use the inits as the replacements for the shared outs bbargs to appease
-  // `inlineBlockBefore`. By this point all of their users have been replaced
-  // or erased so it doesn't matter what goes here.
-  argReplacements.append(inits.begin(), inits.end());
-  rewriter.inlineBlockBefore(forallOp.getBody(), loopOp.getBody(),
-                             loopOp.getBody()->end(), argReplacements);
-
-  rewriter.replaceOp(forallOp, loopOp);
-  return loopOp;
-}
-
-} // namespace
-
-//===---------------------------------------------------------------------===//
-// scf.forall -> pcf.loop
-//===---------------------------------------------------------------------===//
-
-FailureOr<PCF::LoopOp> convertForallToPCF(RewriterBase &rewriter,
-                                          scf::ForallOp forallOp,
-                                          PCF::ScopeAttrInterface scope,
-                                          int64_t numIds) {
-  if (failed(matchForallConversion(forallOp))) {
-    return failure();
-  }
-  return convertForallToPCFImpl(rewriter, forallOp, scope, numIds);
-}
-
-} // namespace mlir::iree_compiler::IREE::PCF
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertForallToPCF.cpp b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertForallToPCF.cpp
new file mode 100644
index 000000000000..e0451fd464af
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/ConvertForallToPCF.cpp
@@ -0,0 +1,662 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFAttrs.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFOps.h"
+#include "iree/compiler/Codegen/Dialect/PCF/IR/PCFTypes.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h"
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
+#include "iree/compiler/Utils/RewriteUtils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/IRMapping.h"
+
+#define DEBUG_TYPE "iree-pcf-convert-forall-to-pcf"
+
+namespace mlir::iree_compiler::IREE::PCF {
+
+#define GEN_PASS_DEF_TESTCONVERTFORALLTOLOOPSPASS
+#define GEN_PASS_DEF_TESTCONVERTFORALLTOGENERICNESTPASS
+#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.h.inc"
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Shared Utilities
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the forall op has LocalMappingAttr mapping attributes,
+/// or the mapping is empty/not present.
+static bool hasEmptyOrLocalMapping(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mapping = forallOp.getMapping();
+  if (!mapping || mapping->empty()) {
+    return true;
+  }
+  return llvm::all_of(mapping.value(),
+                      llvm::IsaPred<IREE::Codegen::LocalMappingAttr>);
+}
+
+/// Returns the permutation from mapping attributes based on their relative
+/// processor IDs. Lower IDs indicate faster varying dimensions. The returned
+/// permutation maps from the fastest-to-slowest order back to the original
+/// forall dimension order.
+///
+/// Example: mapping = [local_mapping<1>, local_mapping<0>]
+///   - local_mapping<0> has lower ID than local_mapping<1>
+///   - So dimension 1 (id 0) is faster, dimension 0 (id 1) is slower
+///   - Returns [1, 0]: position 0 in linearized order maps to dim 1, etc.
+static SmallVector<int64_t> getMappingPermutation(ArrayAttr mapping) {
+  auto mappingRange = mapping.getAsRange<DeviceMappingAttrInterface>();
+  int64_t mappingBase =
+      cast<DeviceMappingAttrInterface>(
+          *std::min_element(mappingRange.begin(), mappingRange.end(),
+                            [](auto a, auto b) {
+                              return a.getMappingId() < b.getMappingId();
+                            }))
+          .getMappingId();
+  return llvm::map_to_vector(
+      mappingRange, [&](auto a) { return a.getMappingId() - mappingBase; });
+}
+
+/// Returns the permutation for processor ID ordering from the forall's
+/// mapping. For an unspecified mapping, returns a reversed sequence (the
+/// natural fastest-to-slowest order is the reverse of the dimension order);
+/// an empty mapping yields an empty permutation.
+static FailureOr<SmallVector<int64_t>>
+getProcessorIdPermutation(scf::ForallOp forallOp) {
+  std::optional<ArrayAttr> mappingAttr = forallOp.getMapping();
+  // Unspecified mappings indicate sequential foralls which we can choose the
+  // iteration order for.
+  if (!mappingAttr) {
+    return llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(forallOp.getRank())));
+  }
+  // Empty mappings are unsupported at the moment. It's unclear when a forall
+  // with an empty mapping would be useful or important.
+  if (mappingAttr.value().empty()) {
+    return SmallVector<int64_t>{};
+  }
+
+  SmallVector<int64_t> perm = getMappingPermutation(mappingAttr.value());
+  if (!isPermutationVector(perm)) {
+    return failure();
+  }
+  return perm;
+}
+
+/// Validates that the forall op can be converted.
+static LogicalResult matchForallConversion(scf::ForallOp forallOp) {
+  scf::InParallelOp terminator = forallOp.getTerminator();
+  for (Operation &op : terminator.getBody()->getOperations()) {
+    // Bail on terminator ops other than parallel insert slice since we don't
+    // know how to convert it.
+    auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
+    if (!insertSliceOp) {
+      return failure();
+    }
+
+    // Bail on non-shared outs destinations.
+    auto bbArgDest = dyn_cast<BlockArgument>(insertSliceOp.getDest());
+    if (!bbArgDest || bbArgDest.getOwner()->getParentOp() != forallOp) {
+      return failure();
+    }
+  }
+
+  for (BlockArgument bbArg : forallOp.getRegionIterArgs()) {
+    for (OpOperand &use : bbArg.getUses()) {
+      // Skip users outside of the terminator. These are replaced with the init.
+      if (use.getOwner()->getParentOp() != terminator) {
+        continue;
+      }
+
+      // Bail if the use is not on the dest of the insert slice.
+      auto insertSliceUser =
+          cast<tensor::ParallelInsertSliceOp>(use.getOwner());
+      if (use != insertSliceUser.getDestMutable()) {
+        return failure();
+      }
+    }
+  }
+  // Ensure the mapping yields a valid permutation.
+  if (failed(getProcessorIdPermutation(forallOp))) {
+    return failure();
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// scf.forall -> pcf.loop Implementation
+//===----------------------------------------------------------------------===//
+
+static PCF::LoopOp convertForallToPCFLoopImpl(RewriterBase &rewriter,
+                                              scf::ForallOp forallOp,
+                                              PCF::ScopeAttrInterface scope,
+                                              int64_t numIds) {
+  assert(succeeded(matchForallConversion(forallOp)) &&
+         "converting unsupported forall op");
+
+  // Maps from fastest -> slowest to current order.
+  SmallVector<int64_t> perm = *getProcessorIdPermutation(forallOp);
+
+  // Maps from current order to fastest -> slowest.
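+  // For example, perm = [2, 0, 1] inverts to invPerm = [1, 2, 0].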
+  SmallVector<int64_t> invPerm = invertPermutationVector(perm);
+
+  // Get the permuted ubs/lbs/steps and save them for later since we need them
+  // to reconstruct the correct ids.
+  SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
+  applyPermutationToVector(mixedUbs, invPerm);
+  SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
+  applyPermutationToVector(mixedLbs, invPerm);
+  SmallVector<OpFoldResult> mixedStep = forallOp.getMixedStep();
+  applyPermutationToVector(mixedStep, invPerm);
+  // Permute the ivs of the body to match the original mapping order by
+  // permuting the uses before we move it over to the new op.
+  permuteValues(rewriter, forallOp.getLoc(), forallOp.getInductionVars(), perm);
+
+  scf::InParallelOp terminator = forallOp.getTerminator();
+  MutableArrayRef<BlockArgument> bodySharedOuts = forallOp.getRegionIterArgs();
+
+  // Replace non-insert slice users outside of the `scf.forall.in_parallel` with
+  // the init values.
+  ValueRange inits = forallOp.getDpsInits();
+  for (auto [init, bbArg] : llvm::zip_equal(inits, bodySharedOuts)) {
+    rewriter.replaceUsesWithIf(bbArg, init, [&](OpOperand &use) {
+      return use.getOwner()->getParentOp() != terminator;
+    });
+  }
+
+  Location loc = forallOp.getLoc();
+
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr numIters = (s0 - s1).ceilDiv(s2);
+  SmallVector<Value> iterationCounts;
+  for (auto [ub, lb, step] : llvm::zip_equal(mixedUbs, mixedLbs, mixedStep)) {
+    OpFoldResult iterCount = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, numIters, ArrayRef<OpFoldResult>{ub, lb, step});
+    iterationCounts.push_back(
+        getValueOrCreateConstantIndexOp(rewriter, loc, iterCount));
+  }
+
+  int64_t numDelinIds = numIds > 0 && iterationCounts.size() > numIds
+                            ? iterationCounts.size() - numIds + 1
+                            : 0;
+  // Reverse the delinearization basis because affine.delinearize_index is from
+  // slowest to fastest varying.
+  SmallVector<Value> delinearizationBasis(
+      llvm::reverse(ArrayRef(iterationCounts).take_back(numDelinIds)));
+  if (!delinearizationBasis.empty()) {
+    AffineExpr mul = s0 * s1;
+    Value total = delinearizationBasis.front();
+    total = std::accumulate(
+        delinearizationBasis.begin() + 1, delinearizationBasis.end(), total,
+        [&](Value l, Value r) {
+          OpFoldResult acc =
+              affine::makeComposedFoldedAffineApply(rewriter, loc, mul, {l, r});
+          return getValueOrCreateConstantIndexOp(rewriter, loc, acc);
+        });
+    // Replace the first |numDelinIds| entries with their product.
+    iterationCounts.erase(iterationCounts.end() - numDelinIds,
+                          iterationCounts.end());
+    iterationCounts.push_back(total);
+  }
+
+  auto loopOp = PCF::LoopOp::create(rewriter, loc, scope, iterationCounts,
+                                    forallOp.getDpsInits());
+  SmallVector<Value> argReplacements;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPointToStart(loopOp.getBody());
+    // If/else here to avoid popping off the front of a vector.
+    if (!delinearizationBasis.empty()) {
+      auto delin = affine::AffineDelinearizeIndexOp::create(
+          rewriter, loc, loopOp.getIdArgs().back(), delinearizationBasis,
+          /*hasOuterBound=*/true);
+      argReplacements.append(loopOp.getIdArgs().begin(),
+                             loopOp.getIdArgs().end() - 1);
+      // Add replacements in reverse to get fastest -> slowest.
+      auto delinResultsReverse = llvm::reverse(delin.getResults());
+      argReplacements.append(delinResultsReverse.begin(),
+                             delinResultsReverse.end());
+    } else {
+      argReplacements.append(loopOp.getIdArgs().begin(),
+                             loopOp.getIdArgs().end());
+    }
+
+    // id * step + lb.
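+    // For example, with lb = 4 and step = 2, ids 0, 1, 2 map to 4, 6, 8.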
+    AffineExpr applyLbAndStep = s0 * s1 + s2;
+    for (auto [id, step, lb] :
+         llvm::zip_equal(argReplacements, mixedStep, mixedLbs)) {
+      OpFoldResult newId = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, applyLbAndStep, {id, step, lb});
+      id = getValueOrCreateConstantIndexOp(rewriter, loc, newId);
+    }
+  }
+
+  // Add parent-only sync scope to the body arg types.
+  Attribute syncScope = PCF::SyncOnReturnAttr::get(rewriter.getContext());
+  for (auto regionRefArg : loopOp.getRegionRefArgs()) {
+    auto srefType = cast<PCF::ShapedRefType>(regionRefArg.getType());
+    auto newSrefType = PCF::ShapedRefType::get(
+        rewriter.getContext(), srefType.getShape(), srefType.getElementType(),
+        srefType.getScope(), syncScope);
+    regionRefArg.setType(newSrefType);
+  }
+
+  rewriter.setInsertionPoint(terminator);
+  llvm::SmallDenseMap<Value, Value> argToReplacementMap;
+  for (auto [bbArg, refArg] :
+       llvm::zip_equal(bodySharedOuts, loopOp.getRegionRefArgs())) {
+    argToReplacementMap[bbArg] = refArg;
+  }
+
+  // Iterate the insert_slice ops in order to retain the order of writes.
+  SmallVector<tensor::ParallelInsertSliceOp> insertOps(
+      terminator.getBody()->getOps<tensor::ParallelInsertSliceOp>());
+  for (tensor::ParallelInsertSliceOp insertSliceOp : insertOps) {
+    PCF::WriteSliceOp::create(
+        rewriter, insertSliceOp.getLoc(), insertSliceOp.getSource(),
+        argToReplacementMap[insertSliceOp.getDest()],
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    rewriter.eraseOp(insertSliceOp);
+  }
+
+  // Replace the terminator with the new terminator kind.
+  rewriter.replaceOpWithNewOp<PCF::ReturnOp>(terminator);
+
+  // Use the inits as the replacements for the shared outs bbargs to appease
+  // `inlineBlockBefore`. By this point all of their users have been replaced
+  // or erased so it doesn't matter what goes here.
+  argReplacements.append(inits.begin(), inits.end());
+  rewriter.inlineBlockBefore(forallOp.getBody(), loopOp.getBody(),
+                             loopOp.getBody()->end(), argReplacements);
+
+  rewriter.replaceOp(forallOp, loopOp);
+  return loopOp;
+}
+
+//===----------------------------------------------------------------------===//
+// scf.forall -> pcf.generic nest Implementation
+//===----------------------------------------------------------------------===//
+
+static PCF::GenericOp
+convertForallToGenericNestImpl(RewriterBase &rewriter, scf::ForallOp forallOp,
+                               ArrayRef<PCF::ScopeAttrInterface> scopes) {
+  assert(succeeded(matchForallConversion(forallOp)) &&
+         "converting unsupported forall op");
+  assert(!scopes.empty() && "at least one scope required");
+
+  Location loc = forallOp.getLoc();
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Get forall outputs to base the tied results on.
+  ValueRange outputs = forallOp.getOutputs();
+  TypeRange resultTypes = forallOp.getResultTypes();
+  scf::InParallelOp terminator = forallOp.getTerminator();
+  MutableArrayRef<BlockArgument> bodySharedOuts = forallOp.getRegionIterArgs();
+
+  // Replace non-insert slice users outside of the `scf.forall.in_parallel` with
+  // the init values.
+  ValueRange inits = forallOp.getDpsInits();
+  for (auto [init, bbArg] : llvm::zip_equal(inits, bodySharedOuts)) {
+    rewriter.replaceUsesWithIf(bbArg, init, [&](OpOperand &use) {
+      return use.getOwner()->getParentOp() != terminator;
+    });
+  }
+
+  // Create nested pcf.generic ops for each scope.
+  SmallVector<bool> isTied(outputs.size(), true);
+  Attribute syncScope = PCF::SyncOnReturnAttr::get(ctx);
+
+  SmallVector<Value> allIds, allCounts;
+
+  // Create outermost generic with inits and results. Use the scope's native
+  // number of processor IDs to determine iteration dimensions.
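+  // E.g. #pcf.sequential and a linearized workgroup scope each report a
+  // single id/count pair.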
+  int64_t outermostNumIds = scopes[0].getNativeNumProcessorIds();
+  PCF::GenericOp outermostGeneric = PCF::GenericOp::create(
+      rewriter, loc, resultTypes, scopes[0], outputs, /*dynamicSizes=*/{},
+      isTied, outermostNumIds, /*syncOnReturn=*/false);
+
+  // Add sync scope to sref types for outermost generic.
+  for (BlockArgument regionRefArg : outermostGeneric.getRegionRefArgs()) {
+    auto srefType = cast<PCF::ShapedRefType>(regionRefArg.getType());
+    auto newSrefType = PCF::ShapedRefType::get(ctx, srefType.getShape(),
+                                               srefType.getElementType(),
+                                               srefType.getScope(), syncScope);
+    regionRefArg.setType(newSrefType);
+  }
+
+  // Collect id/count args from outermost generic.
+  allIds.append(outermostGeneric.getIdArgs().begin(),
+                outermostGeneric.getIdArgs().end());
+  allCounts.append(outermostGeneric.getCountArgs().begin(),
+                   outermostGeneric.getCountArgs().end());
+
+  // Set insertion point to end of execute block for next generic or body.
+  Block &outermostBlock = outermostGeneric.getRegion().front();
+  rewriter.setInsertionPointToEnd(&outermostBlock);
+
+  // Create inner generics (no inits, no results).
+  for (size_t i = 1; i < scopes.size(); ++i) {
+    PCF::ScopeAttrInterface scope = scopes[i];
+    int64_t numIds = scope.getNativeNumProcessorIds();
+    PCF::GenericOp generic = PCF::GenericOp::create(
+        rewriter, loc, TypeRange{}, scope, ValueRange{}, /*dynamicSizes=*/{},
+        SmallVector<bool>{}, numIds, /*syncOnReturn=*/false);
+
+    // Same here, collect id/count args from this generic and update the
+    // insertion point for the next one.
+    allIds.append(generic.getIdArgs().begin(), generic.getIdArgs().end());
+    allCounts.append(generic.getCountArgs().begin(),
+                     generic.getCountArgs().end());
+
+    Block &executeBlock = generic.getRegion().front();
+    rewriter.setInsertionPointToEnd(&executeBlock);
+  }
+
+  // In the innermost block, linearize all ids if there are multiple:
+  //   linearId = id[0] * count[1] * ... * count[n-1]
+  //            + id[1] * count[2] * ... * count[n-1] + ... + id[n-1]
+  //   totalWorkers = count[0] * count[1] * ... * count[n-1]
+  Value linearId;
+  if (allIds.size() == 1) {
+    // Shortcut single id to avoid creating IR.
+    linearId = allIds[0];
+  } else {
+    // Use affine.linearize_index with the |counts| as the linearization basis.
+    SmallVector<OpFoldResult> countOfrs = llvm::map_to_vector(
+        allCounts, [](Value v) -> OpFoldResult { return v; });
+    linearId =
+        affine::AffineLinearizeIndexOp::create(rewriter, loc, allIds, countOfrs,
+                                               /*disjoint=*/false);
+  }
+
+  // Compute total workers as product of all counts.
+  Value totalWorkers = allCounts[0];
+  for (size_t i = 1; i < allCounts.size(); ++i) {
+    totalWorkers =
+        arith::MulIOp::create(rewriter, loc, totalWorkers, allCounts[i]);
+  }
+
+  // Compute total iteration count from forall bounds.
+  SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
+  Value totalIters = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  for (OpFoldResult ub : mixedUbs) {
+    Value dim = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
+    totalIters = arith::MulIOp::create(rewriter, loc, totalIters, dim);
+  }
+
+  // Compute chunk bounds: chunkSize = ceildiv(total, totalWorkers). We'll
+  // create a spillover scf.forall that iterates from the start of the current
+  // chunk (call it start) to min(start + chunkSize, total). This greedily
+  // allocates an equal number of iterations to all workers except the last.
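+  // For example, distributing 70 iterations over 8 workers gives
+  // chunkSize = 9; workers 0-6 take 9 iterations each and worker 7 takes
+  // [63, 70).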
+  Value chunkSize =
+      arith::CeilDivUIOp::create(rewriter, loc, totalIters, totalWorkers);
+  Value linearLb = arith::MulIOp::create(rewriter, loc, linearId, chunkSize);
+  Value linearUbRaw = arith::AddIOp::create(rewriter, loc, linearLb, chunkSize);
+  Value linearUb =
+      arith::MinUIOp::create(rewriter, loc, linearUbRaw, totalIters);
+
+  // Create the inner scf.forall with linearized bounds (no body builder).
+  SmallVector<OpFoldResult> lbs = {linearLb};
+  SmallVector<OpFoldResult> ubs = {linearUb};
+  SmallVector<OpFoldResult> steps = {rewriter.getIndexAttr(1)};
+
+  auto innerForall = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
+                                           /*outputs=*/ValueRange{},
+                                           /*mapping=*/std::nullopt);
+
+  // The scf.forall without body builder creates an empty block with a
+  // terminator. Remove the terminator so we can populate the body.
+  rewriter.eraseOp(innerForall.getBody()->getTerminator());
+
+  // Get the permutation from mapping. The permutation value at position i
+  // indicates the "speed rank" of dimension i (lower = faster varying).
+  // For example, [#iree_codegen.local_mapping<1>,
+  // #iree_codegen.local_mapping<0>] gives perm = [1, 0]:
+  //   - dim 0 has rank 1 (slower)
+  //   - dim 1 has rank 0 (faster)
+  SmallVector<int64_t> perm = *getProcessorIdPermutation(forallOp);
+
+  // Compute bounds ordered from slowest to fastest for delinearization.
+  // affine.delinearize_index expects bounds from slowest to fastest varying.
+  // Sort dimensions by decreasing speed rank (slowest first).
+  SmallVector<std::pair<int64_t, int64_t>> rankAndDim;
+  for (size_t i = 0; i < perm.size(); ++i) {
+    rankAndDim.emplace_back(perm[i], i);
+  }
+  // Sort by rank descending (highest rank = slowest first).
+  llvm::sort(rankAndDim, [](auto &a, auto &b) { return a.first > b.first; });
+
+  // Build permuted upper bounds (slowest to fastest).
+  SmallVector<OpFoldResult> permutedUbs;
+  SmallVector<int64_t> delinToOrigDim; // Maps delinearized result index to
+                                       // original forall dim.
+  for (auto [rank, dim] : rankAndDim) {
+    permutedUbs.push_back(mixedUbs[dim]);
+    delinToOrigDim.push_back(dim);
+  }
+
+  // Map old induction variables to new linearized index.
+  // For 1D: directly use the induction variable.
+  // For multi-D: need delinearization with permutation handling.
+  IRMapping mapping;
+  if (forallOp.getRank() == 1) {
+    // Shortcut rank-1 foralls to avoid creating the delinearize_index.
+    mapping.map(forallOp.getInductionVars()[0],
+                innerForall.getInductionVars()[0]);
+  } else {
+    // For multi-dimensional forall, we need to delinearize.
+    // Delinearize the iteration index back to multi-D indices.
+    rewriter.setInsertionPointToStart(innerForall.getBody());
+    Value linearIdx = innerForall.getInductionVars()[0];
+    auto delinearized = affine::AffineDelinearizeIndexOp::create(
+        rewriter, loc, linearIdx, permutedUbs, /*hasOuterBound=*/true);
+    // The delinearized results are in permuted order (slowest to fastest).
+    // Map them back to the corresponding forall induction variables.
+    // delinToOrigDim[i] tells us which original forall dim corresponds to
+    // delinearized result i.
+    for (size_t i = 0; i < delinToOrigDim.size(); ++i) {
+      int64_t origDim = delinToOrigDim[i];
+      mapping.map(forallOp.getInductionVars()[origDim],
+                  delinearized.getResult(i));
+    }
+  }
+
+  // Clone body operations (except terminator) into inner forall.
+  rewriter.setInsertionPointToEnd(innerForall.getBody());
+
+  for (Operation &op : forallOp.getBody()->without_terminator()) {
+    rewriter.clone(op, mapping);
+  }
+
+  // Convert parallel_insert_slice to pcf.write_slice.
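+  // E.g. `tensor.parallel_insert_slice %s into %out[%i] [1] [1]` becomes
+  // `pcf.write_slice %s into %ref[%i] [1] [1]` targeting the outermost
+  // generic's sref argument.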
+  llvm::SmallDenseMap<Value, Value> bbArgToSref;
+  for (auto [bbArg, refArg] :
+       llvm::zip_equal(bodySharedOuts, outermostGeneric.getRegionRefArgs())) {
+    bbArgToSref[bbArg] = refArg;
+  }
+
+  for (Operation &op : terminator.getBody()->getOperations()) {
+    if (auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op)) {
+      Value src = mapping.lookupOrDefault(insertOp.getSource());
+      Value destRef = bbArgToSref[insertOp.getDest()];
+
+      // Map offsets, sizes, and strides through the mapping.
+      SmallVector<OpFoldResult> offsets, sizes, strides;
+      for (OpFoldResult offset : insertOp.getMixedOffsets()) {
+        if (auto val = dyn_cast<Value>(offset)) {
+          offsets.push_back(mapping.lookupOrDefault(val));
+        } else {
+          offsets.push_back(offset);
+        }
+      }
+      for (OpFoldResult size : insertOp.getMixedSizes()) {
+        if (auto val = dyn_cast<Value>(size)) {
+          sizes.push_back(mapping.lookupOrDefault(val));
+        } else {
+          sizes.push_back(size);
+        }
+      }
+      for (OpFoldResult stride : insertOp.getMixedStrides()) {
+        if (auto val = dyn_cast<Value>(stride)) {
+          strides.push_back(mapping.lookupOrDefault(val));
+        } else {
+          strides.push_back(stride);
+        }
+      }
+
+      PCF::WriteSliceOp::create(rewriter, loc, src, destRef, offsets, sizes,
+                                strides);
+    }
+  }
+
+  // Add an empty in_parallel terminator to the inner forall.
+  scf::InParallelOp::create(rewriter, loc);
+
+  // Add pcf.return terminators to all generic blocks, from innermost to
+  // outermost. We need to walk from innermost out, adding returns.
+  // The innermost is where we currently are.
+  Operation *current = innerForall.getOperation();
+  while (current) {
+    Operation *parent = current->getParentOp();
+    if (auto generic = dyn_cast_or_null<PCF::GenericOp>(parent)) {
+      Block &block = generic.getRegion().front();
+      rewriter.setInsertionPointToEnd(&block);
+      PCF::ReturnOp::create(rewriter, loc);
+      current = generic.getOperation();
+    } else {
+      // Break on the first non-generic (should be the parent of all the IR we
+      // just created).
+      break;
+    }
+  }
+
+  return outermostGeneric;
+}
+
+//===----------------------------------------------------------------------===//
+// scf.forall -> pcf.loop Pass
+//===----------------------------------------------------------------------===//
+
+struct TestConvertForallToLoopsPass final
+    : impl::TestConvertForallToLoopsPassBase<TestConvertForallToLoopsPass> {
+  void runOnOperation() override {
+    SmallVector<scf::ForallOp> opsToConvert;
+    getOperation()->walk([&](scf::ForallOp forallOp) {
+      // Empty mapping, no mapping, and local mapping all map to
+      // `pcf.sequential`. If it is a local mapping, then the lowering pattern
+      // will automatically handle any mapping permutation based on the mapping
+      // attribute's relative id.
+      if (hasEmptyOrLocalMapping(forallOp)) {
+        opsToConvert.push_back(forallOp);
+      }
+    });
+
+    IRRewriter rewriter(getOperation());
+    PCF::ScopeAttrInterface sequentialScope =
+        PCF::SequentialAttr::get(&getContext());
+    for (auto forallOp : opsToConvert) {
+      rewriter.setInsertionPoint(forallOp);
+      if (failed(convertForallToPCFLoop(rewriter, forallOp, sequentialScope))) {
+        forallOp->emitOpError("failed to convert forall");
+        return signalPassFailure();
+      }
+    }
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// scf.forall -> pcf.generic nest Pass
+//===----------------------------------------------------------------------===//
+
+struct TestConvertForallToGenericNestPass final
+    : impl::TestConvertForallToGenericNestPassBase<
+          TestConvertForallToGenericNestPass> {
+  using Base::Base;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Build scope list based on numSequentialScopes.
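+    // E.g. num-sequential-scopes=2 produces two nested #pcf.sequential
+    // generics.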
+    SmallVector<PCF::ScopeAttrInterface> scopeAttrs;
+    for (int64_t i = 0; i < numSequentialScopes; ++i) {
+      scopeAttrs.push_back(PCF::SequentialAttr::get(ctx));
+    }
+
+    if (scopeAttrs.empty()) {
+      emitError(getOperation()->getLoc()) << "no scopes specified";
+      return signalPassFailure();
+    }
+
+    IRRewriter rewriter(ctx);
+    SmallVector<scf::ForallOp> forallOps;
+    getOperation()->walk([&](scf::ForallOp forallOp) {
+      // Only convert foralls with an empty mapping or local_mapping
+      // attributes.
+      if (hasEmptyOrLocalMapping(forallOp)) {
+        forallOps.push_back(forallOp);
+      }
+    });
+
+    for (scf::ForallOp forallOp : forallOps) {
+      rewriter.setInsertionPoint(forallOp);
+      FailureOr<PCF::GenericOp> result =
+          convertForallToGenericNest(rewriter, forallOp, scopeAttrs);
+      if (failed(result)) {
+        forallOp.emitError("failed to convert forall to generic nest");
+        return signalPassFailure();
+      }
+      // Replace the forall results with the generic results.
+      rewriter.replaceOp(forallOp, result->getResults());
+    }
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Public API: scf.forall -> pcf.loop
+//===----------------------------------------------------------------------===//
+
+FailureOr<PCF::LoopOp> convertForallToPCFLoop(RewriterBase &rewriter,
+                                              scf::ForallOp forallOp,
+                                              PCF::ScopeAttrInterface scope,
+                                              int64_t numIds) {
+  if (failed(matchForallConversion(forallOp))) {
+    return failure();
+  }
+  return convertForallToPCFLoopImpl(rewriter, forallOp, scope, numIds);
+}
+
+//===----------------------------------------------------------------------===//
+// Public API: scf.forall -> pcf.generic nest
+//===----------------------------------------------------------------------===//
+
+FailureOr<PCF::GenericOp>
+convertForallToGenericNest(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           ArrayRef<PCF::ScopeAttrInterface> scopes) {
+  if (scopes.empty()) {
+    return forallOp.emitError("at least one scope required");
+  }
+
+  if (failed(matchForallConversion(forallOp))) {
+    return failure();
+  }
+
+  return convertForallToGenericNestImpl(rewriter, forallOp, scopes);
+}
+
+} // namespace mlir::iree_compiler::IREE::PCF
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
index b00926b6e503..f0d68448d9e2 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Passes.td
@@ -9,24 +9,25 @@

 include "mlir/Pass/PassBase.td"

-def ConvertForallToLoopsPass : Pass<"iree-pcf-convert-forall-to-loops", ""> {
-  let summary = "Converts scf.forall ops to pcf.loop";
+def TestConvertForallToLoopsPass : Pass<"iree-pcf-test-convert-forall-to-loops", ""> {
+  let summary = "Test pass for the convertForallToPCFLoop transform";
   let description = [{
-    Test pass for converting `scf.forall` ops without mapping attributes to
-    `pcf.loop` ops with sequential scope.
+    Test pass for converting `scf.forall` ops without mapping attributes or
+    with `iree_codegen.local_mapping` attributes to `pcf.loop` ops with
+    sequential scope.

     The input is IR containing `scf.forall` ops with tensor results and
     `tensor.parallel_insert_slice` terminators. Only forall ops without
-    mapping attributes are converted.
+    mapping attributes or with local_mapping attributes are converted.
     The output replaces each matching `scf.forall` with a `pcf.loop`:
-    - Iteration bounds come from the forall's upper/lower bounds and steps
-    - `tensor.parallel_insert_slice` ops become `pcf.write_slice` ops
-    - Shared output tensors become tied `pcf.sref` region arguments
-    - The scope is set to `#pcf.sequential` for sequential execution
+    - Iteration bounds come from the forall's upper/lower bounds and steps.
+    - `tensor.parallel_insert_slice` ops become `pcf.write_slice` ops.
+    - Shared output tensors become tied `pcf.sref` region arguments.
+    - The scope is set to `#pcf.sequential` for sequential execution.

     The underlying conversion pattern is exposed separately via
-    `convertForallToPCF()` with callbacks for mapping processor IDs to custom
+    `convertForallToPCFLoop()` with callbacks for mapping processor IDs to custom
     execution scopes.
   }];
   let dependentDialects = ["::mlir::iree_compiler::IREE::PCF::PCFDialect"];
@@ -154,4 +155,27 @@ def ResolveTokensPass : Pass<"iree-pcf-resolve-tokens", ""> {
   }];
 }

+def TestConvertForallToGenericNestPass :
+    InterfacePass<"iree-pcf-test-convert-forall-to-generic-nest", "mlir::FunctionOpInterface"> {
+  let summary = "Test pass for the convertForallToGenericNest transform";
+  let description = [{
+    Test pass that converts `scf.forall` ops to a nest of `pcf.generic` ops
+    with an inner `scf.forall` handling spillover iterations. Each scope
+    generates one `pcf.generic` op (outermost first). Worker IDs are
+    linearized, per-worker chunk bounds are computed for work distribution,
+    and chunk indices are delinearized back to the multi-dimensional forall
+    bounds.
+  }];
+  let options = [
+    Option<"numSequentialScopes", "num-sequential-scopes", "int64_t",
+           /*default=*/"1",
+           "Number of pcf.sequential scopes to use">
+  ];
+  let dependentDialects = [
+    "::mlir::iree_compiler::IREE::PCF::PCFDialect",
+    "::mlir::arith::ArithDialect",
+    "::mlir::affine::AffineDialect",
+    "::mlir::scf::SCFDialect"
+  ];
+}
+
 #endif // IREE_CODEGEN_DIALECT_PCF_TRANSFORMS_PASSES
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
index 4ef753c7b0f0..43810a707921 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h
@@ -29,14 +29,28 @@ class ExtractSliceOp;

 namespace mlir::iree_compiler::IREE::PCF {

-// Helper to convert scf.forall ops to pcf.loop by linearizing/delinearizing
-// ids beyond |numIds| into the slowest varying id. Uses
-// DeviceMappingAttrInterface to infer the order of ids from slowest to fastest
-// varying. If |numIds| <= 0, then no linearization/delinearization is done.
-FailureOr<PCF::LoopOp> convertForallToPCF(RewriterBase &rewriter,
-                                          scf::ForallOp forallOp,
-                                          PCF::ScopeAttrInterface scope,
-                                          int64_t numIds = -1);
+/// Converts scf.forall ops to pcf.loop by linearizing/delinearizing ids beyond
+/// |numIds| into the slowest varying id. Uses DeviceMappingAttrInterface to
+/// infer the order of ids from slowest to fastest varying. If |numIds| <= 0,
+/// then no linearization/delinearization is done.
+FailureOr<PCF::LoopOp> convertForallToPCFLoop(RewriterBase &rewriter,
+                                              scf::ForallOp forallOp,
+                                              PCF::ScopeAttrInterface scope,
+                                              int64_t numIds = -1);
+
+/// Converts an scf.forall operation to a nest of pcf.generic operations with
+/// an inner scf.forall handling spillover iterations.
+///
+/// Each scope in |scopes| generates one pcf.generic op (outermost first).
+/// Worker IDs are linearized, per-worker chunk bounds are computed for work
+/// distribution, and chunk indices are delinearized back to the
+/// multi-dimensional forall bounds. Uneven chunks are distributed greedily:
+/// only the lexicographically last worker gets fewer loop iterations.
+///
+/// Returns the outermost pcf.generic op on success.
+FailureOr<PCF::GenericOp>
+convertForallToGenericNest(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           ArrayRef<PCF::ScopeAttrInterface> scopes);

 struct ConsumerFusionParams {
   // List of operands in the consumer that are fused along.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
index 49034482436a..3a47572cf1b3 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/BUILD.bazel
@@ -19,6 +19,7 @@ iree_lit_test_suite(
     srcs = enforce_glob(
         # keep sorted
         [
+            "convert_forall_to_generic_nest.mlir",
             "convert_forall_to_loops.mlir",
             "convert_sref_to_memref.mlir",
             "fuse_consumers.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
index b1d0feff757c..bf5cc93e354f 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/CMakeLists.txt
@@ -14,6 +14,7 @@ iree_lit_test_suite(
   NAME
     lit
   SRCS
+    "convert_forall_to_generic_nest.mlir"
     "convert_forall_to_loops.mlir"
     "convert_sref_to_memref.mlir"
     "fuse_consumers.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_generic_nest.mlir b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_generic_nest.mlir
new file mode 100644
index 000000000000..fafa81aa41db
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_generic_nest.mlir
@@ -0,0 +1,177 @@
+// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-pcf-test-convert-forall-to-generic-nest{num-sequential-scopes=1}))" --split-input-file | FileCheck %s --check-prefix=CHECK-1SCOPE
+// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-pcf-test-convert-forall-to-generic-nest{num-sequential-scopes=2}))" --split-input-file | FileCheck %s --check-prefix=CHECK-2SCOPE
+
+// Single scope: one pcf.generic wrapping an scf.forall.
+// Two scopes: nested pcf.generic ops with linearized ids.
+
+// CHECK-1SCOPE-LABEL: func.func @test_1d_single_scope
+// CHECK-1SCOPE-SAME: %[[INIT:.+]]: tensor<64xf32>
+// CHECK-1SCOPE: %[[RESULT:.+]] = pcf.generic
+// CHECK-1SCOPE-SAME: scope(#pcf.sequential)
+// CHECK-1SCOPE: execute(%[[REF:.+]] = %[[INIT]])[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
+// Tile size computed from total / count.
+// CHECK-1SCOPE: %[[TILE_SIZE:.+]] = arith.ceildivui %{{.+}}, %[[COUNT]]
+// Bounds: start = id * tile_size.
+// CHECK-1SCOPE: %[[START:.+]] = arith.muli %[[ID]], %[[TILE_SIZE]]
+// CHECK-1SCOPE: %[[END_UNCLAMPED:.+]] = arith.addi %[[START]], %[[TILE_SIZE]]
+// CHECK-1SCOPE: %[[END:.+]] = arith.minui %[[END_UNCLAMPED]]
+// Forall from start to end.
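+// Illustrative arithmetic (not checked by FileCheck): with 64 iterations and
+// a worker count of 4, tile_size = ceildiv(64, 4) = 16, so id 1 covers the
+// half-open range [16, 32).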
+// CHECK-1SCOPE: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
+// CHECK-1SCOPE: %[[SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]]
+// CHECK-1SCOPE: pcf.write_slice %[[SLICE]] into %[[REF]][%[[IV]]]
+// CHECK-1SCOPE: pcf.return
+// CHECK-1SCOPE: return %[[RESULT]]
+
+// CHECK-2SCOPE-LABEL: func.func @test_1d_single_scope
+// CHECK-2SCOPE-SAME: %[[INIT:.+]]: tensor<64xf32>
+// CHECK-2SCOPE: %[[RESULT:.+]] = pcf.generic
+// CHECK-2SCOPE-SAME: scope(#pcf.sequential)
+// CHECK-2SCOPE: execute(%[[REF:.+]] = %[[INIT]])[%[[ID0:.+]]: index, %[[COUNT0:.+]]: index]
+// CHECK-2SCOPE: pcf.generic
+// CHECK-2SCOPE-SAME: scope(#pcf.sequential)
+// CHECK-2SCOPE: execute[%[[ID1:.+]]: index, %[[COUNT1:.+]]: index]
+// Linearize the two scope IDs.
+// CHECK-2SCOPE: %[[LIN_ID:.+]] = affine.linearize_index [%[[ID0]], %[[ID1]]] by (%[[COUNT0]], %[[COUNT1]])
+// CHECK-2SCOPE: %[[TOTAL_COUNT:.+]] = arith.muli %[[COUNT0]], %[[COUNT1]]
+// Tile size computed from total / linearized count.
+// CHECK-2SCOPE: %[[TILE_SIZE:.+]] = arith.ceildivui %{{.+}}, %[[TOTAL_COUNT]]
+// Bounds using the linearized ID.
+// CHECK-2SCOPE: %[[START:.+]] = arith.muli %[[LIN_ID]], %[[TILE_SIZE]]
+// CHECK-2SCOPE: %[[END_UNCLAMPED:.+]] = arith.addi %[[START]], %[[TILE_SIZE]]
+// CHECK-2SCOPE: %[[END:.+]] = arith.minui %[[END_UNCLAMPED]]
+// Forall from start to end.
+// CHECK-2SCOPE: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
+// CHECK-2SCOPE: pcf.write_slice %{{.+}} into %[[REF]][%[[IV]]]
+// CHECK-2SCOPE: pcf.return
+// CHECK-2SCOPE: pcf.return
+// CHECK-2SCOPE: return %[[RESULT]]
+func.func @test_1d_single_scope(%init: tensor<64xf32>) -> tensor<64xf32> {
+  %result = scf.forall (%i) in (64) shared_outs(%out = %init) -> tensor<64xf32> {
+    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<64xf32> to tensor<1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i] [1] [1] : tensor<1xf32> into tensor<64xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>]}
+  return %result : tensor<64xf32>
+}
+
+// -----
+
+// CHECK-1SCOPE-LABEL: func.func @test_2d_multi_dim
+// CHECK-1SCOPE-SAME: %[[INIT:.+]]: tensor<64x128xf32>
+// CHECK-1SCOPE: pcf.generic
+// CHECK-1SCOPE: execute(%[[REF:.+]] = %[[INIT]])[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
+// Tile size from total elements / count.
+// CHECK-1SCOPE: %[[TILE_SIZE:.+]] = arith.ceildivui %{{.+}}, %[[COUNT]]
+// CHECK-1SCOPE: %[[START:.+]] = arith.muli %[[ID]], %[[TILE_SIZE]]
+// CHECK-1SCOPE: %[[END_UNCLAMPED:.+]] = arith.addi %[[START]], %[[TILE_SIZE]]
+// CHECK-1SCOPE: %[[END:.+]] = arith.minui %[[END_UNCLAMPED]]
+// CHECK-1SCOPE: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
+// Delinearize into (64, 128) to recover the 2D indices.
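+// Illustrative arithmetic (not checked by FileCheck): IV = 130 delinearizes
+// to (1, 2), since 130 = 1 * 128 + 2.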
+// CHECK-1SCOPE: %[[INDICES:.+]]:2 = affine.delinearize_index %[[IV]] into (64, 128)
+// CHECK-1SCOPE: %[[SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[INDICES]]#0, %[[INDICES]]#1]
+// CHECK-1SCOPE: pcf.write_slice %[[SLICE]] into %[[REF]][%[[INDICES]]#0, %[[INDICES]]#1]
+// CHECK-1SCOPE: pcf.return
+
+// CHECK-2SCOPE-LABEL: func.func @test_2d_multi_dim
+// CHECK-2SCOPE: pcf.generic
+// CHECK-2SCOPE-SAME: scope(#pcf.sequential)
+// CHECK-2SCOPE: execute(%[[REF:.+]] = %{{.+}})[%[[ID0:.+]]: index, %[[COUNT0:.+]]: index]
+// CHECK-2SCOPE: pcf.generic
+// CHECK-2SCOPE-SAME: scope(#pcf.sequential)
+// CHECK-2SCOPE: execute[%[[ID1:.+]]: index, %[[COUNT1:.+]]: index]
+// CHECK-2SCOPE: %[[LIN_ID:.+]] = affine.linearize_index [%[[ID0]], %[[ID1]]] by (%[[COUNT0]], %[[COUNT1]])
+// CHECK-2SCOPE: %[[TOTAL_COUNT:.+]] = arith.muli %[[COUNT0]], %[[COUNT1]]
+// CHECK-2SCOPE: scf.forall (%[[IV:.+]]) =
+// Delinearize the forall IV to recover the 2D indices.
+// CHECK-2SCOPE: %[[INDICES:.+]]:2 = affine.delinearize_index %[[IV]] into (64, 128)
+// CHECK-2SCOPE: pcf.write_slice %{{.+}} into %[[REF]][%[[INDICES]]#0, %[[INDICES]]#1]
+// CHECK-2SCOPE: pcf.return
+// CHECK-2SCOPE: pcf.return
+func.func @test_2d_multi_dim(%init: tensor<64x128xf32>) -> tensor<64x128xf32> {
+  %result = scf.forall (%i, %j) in (64, 128) shared_outs(%out = %init) -> tensor<64x128xf32> {
+    %slice = tensor.extract_slice %out[%i, %j] [1, 1] [1, 1] : tensor<64x128xf32> to tensor<1x1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i, %j] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<64x128xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<1>, #iree_codegen.local_mapping<0>]}
+  return %result : tensor<64x128xf32>
+}
+
+// -----
+
+// Test mapping permutation where dimensions are NOT in natural order.
+// Forall dimensions: (%i:4, %j:8) with mapping [local_mapping<0>, local_mapping<1>].
+// This means: dim 0 (i) corresponds to id 0 (fast), dim 1 (j) to id 1 (slow).
+// Linearization should be: j * 4 + i (not i * 8 + j).
+// Delinearization with basis (8, 4) gives (j, i), which must be permuted.
+
+// CHECK-1SCOPE-LABEL: func.func @test_permutation
+// CHECK-1SCOPE-SAME: %[[INIT:.+]]: tensor<4x8xf32>
+// CHECK-1SCOPE: pcf.generic
+// CHECK-1SCOPE: execute(%[[REF:.+]] = %[[INIT]])[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
+// CHECK-1SCOPE: scf.forall (%[[IV:.+]]) =
+// Delinearization basis should be (8, 4) - slow dimension first.
+// CHECK-1SCOPE: %[[INDICES:.+]]:2 = affine.delinearize_index %[[IV]] into (8, 4)
+// Permuted indices: [#1, #0] maps (j, i) back to (i, j) for tensor access.
+// CHECK-1SCOPE: %[[SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[INDICES]]#1, %[[INDICES]]#0]
+// CHECK-1SCOPE: pcf.write_slice %[[SLICE]] into %[[REF]][%[[INDICES]]#1, %[[INDICES]]#0]
+// CHECK-1SCOPE: pcf.return
+func.func @test_permutation(%init: tensor<4x8xf32>) -> tensor<4x8xf32> {
+  %result = scf.forall (%i, %j) in (4, 8) shared_outs(%out = %init) -> tensor<4x8xf32> {
+    %slice = tensor.extract_slice %out[%i, %j] [1, 1] [1, 1] : tensor<4x8xf32> to tensor<1x1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i, %j] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<4x8xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<0>, #iree_codegen.local_mapping<1>]}
+  return %result : tensor<4x8xf32>
+}
+
+// -----
+
+// Test 3D local_mapping with permutation [1, 2, 0].
+// This means: dim0 -> position 1, dim1 -> position 2, dim2 -> position 0.
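+// (Positions are the relative ids from local_mapping; lower ids vary faster.
+// Illustrative arithmetic, not checked: the linear index works out to
+// (j * 4 + i) * 16 + k for bounds (i:4, j:8, k:16).)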
+// So iteration order from fastest to slowest is: dim2, dim0, dim1.
+// Bounds are (4, 8, 16), so the delinearization basis should be (8, 4, 16),
+// ordered from slowest to fastest: dim1 (8), dim0 (4), dim2 (16).
+
+// CHECK-1SCOPE-LABEL: func.func @test_3d_permutation
+// CHECK-1SCOPE-SAME: %[[INIT:.+]]: tensor<4x8x16xf32>
+// CHECK-1SCOPE: pcf.generic
+// CHECK-1SCOPE: execute(%[[REF:.+]] = %[[INIT]])[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
+// CHECK-1SCOPE: scf.forall (%[[IV:.+]]) =
+// Delinearization basis should be (8, 4, 16) - slowest to fastest.
+// CHECK-1SCOPE: %[[INDICES:.+]]:3 = affine.delinearize_index %[[IV]] into (8, 4, 16)
+// Permuted indices: [#1, #0, #2] maps back to the original dim order.
+// CHECK-1SCOPE: %[[SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[INDICES]]#1, %[[INDICES]]#0, %[[INDICES]]#2]
+// CHECK-1SCOPE: pcf.write_slice %[[SLICE]] into %[[REF]][%[[INDICES]]#1, %[[INDICES]]#0, %[[INDICES]]#2]
+// CHECK-1SCOPE: pcf.return
+func.func @test_3d_permutation(%init: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %result = scf.forall (%i, %j, %k) in (4, 8, 16) shared_outs(%out = %init) -> tensor<4x8x16xf32> {
+    %slice = tensor.extract_slice %out[%i, %j, %k] [1, 1, 1] [1, 1, 1] : tensor<4x8x16xf32> to tensor<1x1x1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i, %j, %k] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf32> into tensor<4x8x16xf32>
+    }
+  } {mapping = [#iree_codegen.local_mapping<1>, #iree_codegen.local_mapping<2>, #iree_codegen.local_mapping<0>]}
+  return %result : tensor<4x8x16xf32>
+}
+
+// -----
+
+// CHECK-1SCOPE-LABEL: func.func @test_empty_mapping
+// CHECK-1SCOPE-SAME: %[[INIT:.+]]: tensor<64xf32>
+// CHECK-1SCOPE: pcf.generic
+// CHECK-1SCOPE: execute(%[[REF:.+]] = %[[INIT]])[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
+// CHECK-1SCOPE: scf.forall (%[[IV:.+]]) =
+// CHECK-1SCOPE: %[[SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]]]
+// CHECK-1SCOPE: pcf.write_slice %[[SLICE]] into %[[REF]][%[[IV]]]
+// CHECK-1SCOPE: pcf.return
+func.func @test_empty_mapping(%init: tensor<64xf32>) -> tensor<64xf32> {
+  %result = scf.forall (%i) in (64) shared_outs(%out = %init) -> tensor<64xf32> {
+    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<64xf32> to tensor<1xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %slice into %out[%i] [1] [1] : tensor<1xf32> into tensor<64xf32>
+    }
+  }
+  return %result : tensor<64xf32>
+}
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_loops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_loops.mlir
index e18a07e98453..b74676886bcb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_loops.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms/test/convert_forall_to_loops.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-convert-forall-to-loops)" --split-input-file | FileCheck %s
+// RUN: iree-opt %s --pass-pipeline="builtin.module(iree-pcf-test-convert-forall-to-loops)" --split-input-file | FileCheck %s

 func.func @convert_forall(%arg0: tensor, %init: tensor, %d0: index) -> tensor {
   %0 = scf.forall (%id0, %id1) in (%d0, 32) shared_outs(%iter = %init) -> (tensor) {
diff --git a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CodegenExternalModels.cpp b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CodegenExternalModels.cpp
index 4ed7454fe3ae..670f399381de 100644
--- a/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CodegenExternalModels.cpp
+++ b/compiler/src/iree/compiler/Codegen/ExternalInterfaces/CodegenExternalModels.cpp
@@ -112,6 +112,13 @@ struct WorkgroupScopeAttrModel final
     // Allocating workgroup memory unsupported.
     return failure();
   }
+
+  int64_t getNativeNumProcessorIds(Attribute attr) const {
+    auto workgroupScopeAttr = cast<IREE::Codegen::WorkgroupScopeAttr>(attr);
+    // When linearize is true, all workgroup IDs are combined into one.
+    // When false, there are 3 native IDs (x, y, z).
+    return workgroupScopeAttr.getLinearize() ? 1 : 3;
+  }
 };

 class CodegenPCFConversionInterface : public PCFConversionDialectInterface {