1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -86,6 +86,7 @@ iree_compiler_cc_library(
"ConvertAccGEMMToGEMMPass.cpp",
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertForallToGenericNestWorkgroup.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"ConvertUnsupportedFloatArithPass.cpp",
"ConvertWorkgroupForallToPCF.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -79,6 +79,7 @@ iree_cc_library(
"ConvertAccGEMMToGEMMPass.cpp"
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertForallToGenericNestWorkgroup.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvertUnsupportedFloatArithPass.cpp"
"ConvertWorkgroupForallToPCF.cpp"
@@ -0,0 +1,70 @@
// Copyright 2026 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTFORALLTOGENERICNESTWORKGROUPPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Returns true if the forall op has WorkgroupMappingAttr mapping attributes.
static bool hasWorkgroupMapping(scf::ForallOp forallOp) {
  std::optional<ArrayAttr> mapping = forallOp.getMapping();
  if (!mapping || mapping->empty()) {
    return false;
  }
  return llvm::all_of(mapping.value(),
                      llvm::IsaPred<IREE::Codegen::WorkgroupMappingAttr>);
}

struct ConvertForallToGenericNestWorkgroupPass final
    : public impl::ConvertForallToGenericNestWorkgroupPassBase<
          ConvertForallToGenericNestWorkgroupPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Always use linearized workgroup scope (1 id).
    // Interface is implemented via external models hence the cast.
    auto scope = cast<IREE::PCF::ScopeAttrInterface>(
        IREE::Codegen::WorkgroupScopeAttr::get(ctx, /*linearize=*/true));

    SmallVector<IREE::PCF::ScopeAttrInterface> scopes = {scope};

    IRRewriter rewriter(ctx);
    SmallVector<scf::ForallOp> forallOps;
    getOperation()->walk([&](scf::ForallOp forallOp) {
      // Only convert foralls with workgroup mapping attributes.
      if (hasWorkgroupMapping(forallOp)) {
        forallOps.push_back(forallOp);
      }
    });

    for (scf::ForallOp forallOp : forallOps) {
      rewriter.setInsertionPoint(forallOp);
      FailureOr<IREE::PCF::GenericOp> result =
          IREE::PCF::convertForallToGenericNest(rewriter, forallOp, scopes);
      if (failed(result)) {
        forallOp.emitError("failed to convert forall to generic nest");
        return signalPassFailure();
      }
      // Replace forall results with generic results.
      rewriter.replaceOp(forallOp, result->getResults());
    }
  }
};

} // namespace
} // namespace mlir::iree_compiler
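The pass hard-codes linearize=true on the workgroup scope, so the pcf.generic region receives a single flattened id/count pair no matter how many workgroup dimensions the target exposes. As a rough editorial sketch of what that linearization means (the actual lowering is defined by the PCF/Codegen dialects, not by this snippet; the x-fastest convention is an assumption):

#include <cstdint>

struct LinearizedWorkgroup {
  int64_t id;    // flattened workgroup id handed to the pcf.generic region
  int64_t count; // total number of workgroups
};

// Conventional x-fastest flattening of a 3-D workgroup id space.
LinearizedWorkgroup linearize(int64_t idX, int64_t idY, int64_t idZ,
                              int64_t dimX, int64_t dimY, int64_t dimZ) {
  return {idX + dimX * (idY + dimY * idZ), dimX * dimY * dimZ};
}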
@@ -53,7 +53,8 @@ ConvertWorkgroupForall::matchAndRewrite(scf::ForallOp op,
  auto scope = cast<IREE::PCF::ScopeAttrInterface>(
      IREE::Codegen::WorkgroupScopeAttr::get(rewriter.getContext(),
                                             /*linearize=*/true));
- FailureOr<IREE::PCF::LoopOp> res = convertForallToPCF(rewriter, op, scope, 1);
+ FailureOr<IREE::PCF::LoopOp> res =
+     convertForallToPCFLoop(rewriter, op, scope, 1);
  if (failed(res)) {
    return failure();
  }
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -18,6 +18,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
17 changes: 17 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -155,6 +155,23 @@ def ConvertWorkgroupForallToPCFPass
  let dependentDialects = ["iree_compiler::IREE::PCF::PCFDialect"];
}

def ConvertForallToGenericNestWorkgroupPass
    : InterfacePass<"iree-codegen-convert-forall-to-generic-nest-workgroup",
                    "mlir::FunctionOpInterface"> {
  let summary = "Converts scf.forall ops with workgroup mapping to pcf.generic";
  let description = [{
    Converts `scf.forall` ops with `#iree_codegen.workgroup_mapping` attributes
    to a `pcf.generic` op using workgroup scope. The pass always linearizes
    workgroup IDs to a single dimension.
  }];
  let dependentDialects = [
    "iree_compiler::IREE::PCF::PCFDialect",
    "arith::ArithDialect",
    "affine::AffineDialect",
    "scf::SCFDialect"
  ];
}

def CombineLayoutTransformationPass :
    InterfacePass<"iree-codegen-combine-layout-transformation", "mlir::FunctionOpInterface"> {
  let summary =
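For context, a pass defined this way is normally added to a function pass pipeline through its tablegen-generated factory. A minimal wiring sketch, assuming the conventional createConvertForallToGenericNestWorkgroupPass() factory generated from this definition (the PR itself does not add any pipeline usage):

#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

namespace mlir::iree_compiler {

// Sketch only: run the new conversion on each function, then canonicalize to
// clean up the index arithmetic it introduces.
static void addForallToGenericNestPasses(OpPassManager &funcPassManager) {
  funcPassManager.addPass(createConvertForallToGenericNestWorkgroupPass());
  funcPassManager.addPass(createCanonicalizerPass());
}

} // namespace mlir::iree_compiler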
@@ -33,6 +33,7 @@ iree_lit_test_suite(
"convert_accgemm_to_gemm.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_bf16_to_uint16_buffers.mlir",
"convert_forall_to_generic_nest_workgroup.mlir",
"convert_hal_descriptor_type_to_gpu_address_space.mlir",
"convert_to_destination_passing_style.mlir",
"convert_unsupported_float_arith.mlir",
@@ -28,6 +28,7 @@ iree_lit_test_suite(
"convert_accgemm_to_gemm.mlir"
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_forall_to_generic_nest_workgroup.mlir"
"convert_hal_descriptor_type_to_gpu_address_space.mlir"
"convert_to_destination_passing_style.mlir"
"convert_unsupported_float_arith.mlir"
@@ -0,0 +1,29 @@
// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-codegen-convert-forall-to-generic-nest-workgroup))" --allow-unregistered-dialect --split-input-file | FileCheck %s

// Test that workgroup scope creates 1 id/count pair with linearized scope.

// CHECK-LABEL: func.func @test_workgroup_scope
// CHECK: pcf.generic
// CHECK-SAME: scope(#iree_codegen.workgroup_scope<linearize>)
// CHECK: execute(%{{.*}})[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
// Chunk size computed from total iterations / worker count.
// CHECK: %[[CHUNK:.+]] = arith.ceildivui
// Start = id * chunk_size.
// CHECK: %[[START:.+]] = arith.muli %[[ID]], %[[CHUNK]]
// End = min(start + chunk_size, total).
// CHECK: %[[END_RAW:.+]] = arith.addi %[[START]], %[[CHUNK]]
// CHECK: %[[END:.+]] = arith.minui %[[END_RAW]]
// CHECK: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
// CHECK: "foo.body"(%[[IV]])
// CHECK: pcf.write_slice
// CHECK: pcf.return
func.func @test_workgroup_scope(%init: tensor<64xf32>) -> tensor<64xf32> {
  %result = scf.forall (%i) in (64) shared_outs(%out = %init) -> tensor<64xf32> {
    "foo.body"(%i) : (index) -> ()
    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<64xf32> to tensor<1xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into %out[%i] [1] [1] : tensor<1xf32> into tensor<64xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
  return %result : tensor<64xf32>
}
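The CHECK lines above pin down the distribution scheme: each worker takes one contiguous chunk of the iteration space, derived from its id and the worker count. Purely to illustrate that arithmetic (not code from this PR):

#include <algorithm>
#include <cstdint>

struct IterationRange {
  int64_t start;
  int64_t end; // exclusive
};

// Per-worker range matching the generated arith ops: chunk = ceildiv(total,
// count), start = id * chunk, end = min(start + chunk, total).
IterationRange workerRange(int64_t id, int64_t count, int64_t total) {
  int64_t chunk = (total + count - 1) / count;
  int64_t start = id * chunk;
  int64_t end = std::min(start + chunk, total);
  return {start, end};
}

// e.g. total = 64, count = 3 gives chunks of 22: [0, 22), [22, 44), [44, 64).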
@@ -83,6 +83,11 @@ struct SubgroupScopeModel
                                   MLIRContext *context) const {
    return gpu::AddressSpaceAttr::get(context, gpu::AddressSpace::Workgroup);
  }

  int64_t getNativeNumProcessorIds(Attribute attr) const {
    // SubgroupScope natively provides a single 1D processor ID (subgroup_id).
    return 1;
  }
};

/// External model for LaneScopeAttr implementing ScopeAttrInterface.
@@ -133,6 +138,11 @@ struct LaneScopeModel
    // logic to allocate + subview.
    return failure();
  }

  int64_t getNativeNumProcessorIds(Attribute attr) const {
    // LaneScope natively provides a single 1D processor ID (lane_id).
    return 1;
  }
};

} // namespace
@@ -60,6 +60,7 @@ iree_compiler_cc_library(
name = "GPUTransforms",
srcs = [
"CombineBarrierRegions.cpp",
"ConvertForallToGenericNestGPU.cpp",
"DistributeInnerTiledToLanes.cpp",
"ExpandUndistributedInnerTiles.cpp",
"LowerIREEGPUOps.cpp",
@@ -77,6 +78,8 @@ iree_compiler_cc_library(
":PassesIncGen",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Dialect/PCF/IR",
"//compiler/src/iree/compiler/Codegen/Dialect/PCF/Transforms",
"//compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR:IREEVectorExtDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
@@ -49,6 +49,7 @@ iree_cc_library(
"Transforms.h"
SRCS
"CombineBarrierRegions.cpp"
"ConvertForallToGenericNestGPU.cpp"
"DistributeInnerTiledToLanes.cpp"
"ExpandUndistributedInnerTiles.cpp"
"LowerIREEGPUOps.cpp"
@@ -84,6 +85,8 @@ iree_cc_library(
MLIRVectorUtils
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Dialect::PCF::IR
iree::compiler::Codegen::Dialect::PCF::Transforms
iree::compiler::Codegen::Dialect::VectorExt::IR::IREEVectorExtDialect
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
@@ -0,0 +1,84 @@
// Copyright 2026 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler::IREE::GPU {

#define GEN_PASS_DEF_CONVERTFORALLTOGENERICNESTGPUPASS
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h.inc"

namespace {

/// Returns true if the forall op has gpu::GPUThreadMappingAttr mapping
/// attributes.
static bool hasThreadMapping(scf::ForallOp forallOp) {
  std::optional<ArrayAttr> mapping = forallOp.getMapping();
  if (!mapping || mapping->empty()) {
    return false;
  }
  return llvm::all_of(mapping.value(),
                      llvm::IsaPred<gpu::GPUThreadMappingAttr>);
}

/// Runs the conversion for forall ops with gpu.thread mapping to nested
/// pcf.generic ops.
static LogicalResult runConversion(Operation *op) {
  MLIRContext *ctx = op->getContext();

  // Create nested scopes: subgroup (outer) + lane (inner).
  // Interface is implemented via external models hence the casts.
  auto subgroupScope =
      cast<PCF::ScopeAttrInterface>(SubgroupScopeAttr::get(ctx));
  auto laneScope = cast<PCF::ScopeAttrInterface>(LaneScopeAttr::get(ctx));

  SmallVector<PCF::ScopeAttrInterface> scopes = {subgroupScope, laneScope};

  IRRewriter rewriter(ctx);
  SmallVector<scf::ForallOp> forallOps;
  op->walk([&](scf::ForallOp forallOp) {
    // Only convert foralls with gpu.thread mapping attributes.
    if (hasThreadMapping(forallOp)) {
      forallOps.push_back(forallOp);
    }
  });

  for (scf::ForallOp forallOp : forallOps) {
    rewriter.setInsertionPoint(forallOp);
    FailureOr<PCF::GenericOp> result =
        PCF::convertForallToGenericNest(rewriter, forallOp, scopes);
    if (failed(result)) {
      forallOp.emitError("failed to convert forall to generic nest");
      return failure();
    }
    // Replace forall results with generic results.
    rewriter.replaceOp(forallOp, result->getResults());
  }
  return success();
}

struct ConvertForallToGenericNestGPUPass final
    : public impl::ConvertForallToGenericNestGPUPassBase<
          ConvertForallToGenericNestGPUPass> {
  using Base::Base;

  void runOnOperation() override {
    if (failed(runConversion(getOperation()))) {
      return signalPassFailure();
    }
  }
};

} // namespace
} // namespace mlir::iree_compiler::IREE::GPU
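The GPU variant mirrors the Common pass almost line for line; only the mapping predicate and the scope nest differ (subgroup outer, lane inner, in the order they appear in scopes). As an editorial sketch, not part of the PR, the shared structure could be factored into one driver parameterized on both, using only the APIs visible above:

#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler {

// Collect the mapped foralls first so the walk is not invalidated by the
// rewrites, then convert each one and splice in the pcf.generic results.
static LogicalResult
convertMappedForalls(Operation *root,
                     ArrayRef<IREE::PCF::ScopeAttrInterface> scopes,
                     llvm::function_ref<bool(scf::ForallOp)> shouldConvert) {
  IRRewriter rewriter(root->getContext());
  SmallVector<scf::ForallOp> forallOps;
  root->walk([&](scf::ForallOp forallOp) {
    if (shouldConvert(forallOp))
      forallOps.push_back(forallOp);
  });
  for (scf::ForallOp forallOp : forallOps) {
    rewriter.setInsertionPoint(forallOp);
    FailureOr<IREE::PCF::GenericOp> result =
        IREE::PCF::convertForallToGenericNest(rewriter, forallOp, scopes);
    if (failed(result))
      return forallOp.emitError("failed to convert forall to generic nest");
    rewriter.replaceOp(forallOp, result->getResults());
  }
  return success();
}

} // namespace mlir::iree_compiler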
@@ -71,4 +71,21 @@ def VectorizeIREEGPUOpsPass :
  ];
}

def ConvertForallToGenericNestGPUPass :
    InterfacePass<"iree-gpu-convert-forall-to-generic-nest",
                  "mlir::FunctionOpInterface"> {
  let summary = "Converts scf.forall ops with GPU mapping to pcf.generic";
  let description = [{
    Converts scf.forall ops with gpu.thread mapping to nested pcf.generic ops
    using subgroup (outer) and lane (inner) scopes.
  }];
  let dependentDialects = [
    "::mlir::iree_compiler::IREE::PCF::PCFDialect",
    "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
    "::mlir::arith::ArithDialect",
    "::mlir::affine::AffineDialect",
    "::mlir::scf::SCFDialect",
  ];
}

#endif // IREE_CODEGEN_DIALECt_GPU_TRANSFORMS_PASSES
@@ -20,6 +20,7 @@ iree_lit_test_suite(
# keep sorted
[
"combine_barrier_regions.mlir",
"convert_forall_to_generic_nest_gpu.mlir",
"distribute_inner_tiled_to_lanes.mlir",
"expand_undistributed_inner_tiles.mlir",
],
@@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"combine_barrier_regions.mlir"
"convert_forall_to_generic_nest_gpu.mlir"
"distribute_inner_tiled_to_lanes.mlir"
"expand_undistributed_inner_tiles.mlir"
TOOLS