
Commit c938b03

[Codegen][PCF] Add convertForallToGenericNest transform
Add a transform that converts scf.forall operations into multi-level pcf.generic nests. This is required when a single scf.forall mapping type needs to map to multiple scopes; the immediate case is converting thread-mapped scf.forall ops to combined subgroup + lane scopes.

In that case we cannot simply convert to pcf.loop because of how the automatic redistribution in pcf.loop's lowering works: the iterations of the scf.forall need to be redistributed to *all* workers across both scopes. Making either of the two scopes (subgroup or lane) a pcf.loop therefore produces the wrong IR structure: looping over subgroups fails to predicate the lanes, while looping over lanes introduces more thread divergence than normal (normally only a single subgroup may exhibit divergence).

The transform instead creates an outer loop over workers and an inner scf.forall over the per-worker iteration range. For multi-dimensional cases, affine.linearize_index and affine.delinearize_index are used to flatten/unflatten indices appropriately.

Additionally, adds a new method, getNativeNumProcessorIds, to ScopeAttrInterface; this is needed to query the number of ids to generate.
1 parent: ab71aab

29 files changed: +1222 −334 lines
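The core of what the generated nest emits is the per-worker bounds computation described in the commit message above and verified by the new lit test below. A minimal sketch in MLIR C++, assuming a hypothetical helper named computeWorkerRange (illustrative only, not code from this commit):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

namespace {
// Half-open range [start, end) of forall iterations handled by one worker.
struct WorkerRange {
  mlir::Value start;
  mlir::Value end;
};

// Hypothetical helper: given a worker `id`, the worker `count`, and the
// `total` iteration count, compute the chunk this worker owns. Mirrors the
// arith.ceildivui / muli / addi / minui sequence checked in the lit test.
WorkerRange computeWorkerRange(mlir::OpBuilder &b, mlir::Location loc,
                               mlir::Value id, mlir::Value count,
                               mlir::Value total) {
  // chunk = ceildiv(total, count)
  mlir::Value chunk = b.create<mlir::arith::CeilDivUIOp>(loc, total, count);
  // start = id * chunk
  mlir::Value start = b.create<mlir::arith::MulIOp>(loc, id, chunk);
  // end = min(start + chunk, total)
  mlir::Value rawEnd = b.create<mlir::arith::AddIOp>(loc, start, chunk);
  mlir::Value end = b.create<mlir::arith::MinUIOp>(loc, rawEnd, total);
  return {start, end};
}
} // namespace

The inner scf.forall then iterates over [start, end) for each worker, keeping every worker across both scopes busy without the predication and divergence problems described above.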

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ iree_compiler_cc_library(
         "TensorToVectorVectorizePad.cpp",
         "TestExecutablePreprocessing.cpp",
         "TestPartitionableLoopsInterface.cpp",
+        "ConvertForallToGenericNestWorkgroup.cpp",
         "TileAndDistributeToWorkgroupsPass.cpp",
         "TileAndFuseUtils.cpp",
         "TileDispatchUsingForall.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ iree_cc_library(
         "TensorToVectorVectorizePad.cpp"
         "TestExecutablePreprocessing.cpp"
         "TestPartitionableLoopsInterface.cpp"
+        "ConvertForallToGenericNestWorkgroup.cpp"
         "TileAndDistributeToWorkgroupsPass.cpp"
         "TileAndFuseUtils.cpp"
         "TileDispatchUsingForall.cpp"
compiler/src/iree/compiler/Codegen/Common/ConvertForallToGenericNestWorkgroup.cpp

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
// Copyright 2026 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"
#include "iree/compiler/Codegen/Dialect/PCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTFORALLTOGENERICNESTWORKGROUPPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Returns true if the forall op has WorkgroupMappingAttr mapping attributes.
static bool hasWorkgroupMapping(scf::ForallOp forallOp) {
  std::optional<ArrayAttr> mapping = forallOp.getMapping();
  if (!mapping || mapping->empty()) {
    return false;
  }
  return llvm::all_of(mapping.value(),
                      llvm::IsaPred<IREE::Codegen::WorkgroupMappingAttr>);
}

struct ConvertForallToGenericNestWorkgroupPass final
    : public impl::ConvertForallToGenericNestWorkgroupPassBase<
          ConvertForallToGenericNestWorkgroupPass> {
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Always use linearized workgroup scope (1 id).
    // Interface is implemented via external models hence the cast.
    auto scope = cast<IREE::PCF::ScopeAttrInterface>(
        IREE::Codegen::WorkgroupScopeAttr::get(ctx, /*linearize=*/true));

    SmallVector<IREE::PCF::ScopeAttrInterface> scopes = {scope};

    IRRewriter rewriter(ctx);
    SmallVector<scf::ForallOp> forallOps;
    getOperation()->walk([&](scf::ForallOp forallOp) {
      // Only convert foralls with workgroup mapping attributes.
      if (hasWorkgroupMapping(forallOp)) {
        forallOps.push_back(forallOp);
      }
    });

    for (scf::ForallOp forallOp : forallOps) {
      rewriter.setInsertionPoint(forallOp);
      FailureOr<IREE::PCF::GenericOp> result =
          IREE::PCF::convertForallToGenericNest(rewriter, forallOp, scopes);
      if (failed(result)) {
        forallOp.emitError("failed to convert forall to generic nest");
        return signalPassFailure();
      }
      // Replace forall results with generic results.
      rewriter.replaceOp(forallOp, result->getResults());
    }
  }
};

} // namespace
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/ConvertWorkgroupForallToPCF.cpp

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ ConvertWorkgroupForall::matchAndRewrite(scf::ForallOp op,
   auto scope = cast<IREE::PCF::ScopeAttrInterface>(
       IREE::Codegen::WorkgroupScopeAttr::get(rewriter.getContext(),
                                              /*linearize=*/true));
-  FailureOr<IREE::PCF::LoopOp> res = convertForallToPCF(rewriter, op, scope, 1);
+  FailureOr<IREE::PCF::LoopOp> res =
+      convertForallToPCFLoop(rewriter, op, scope, 1);
   if (failed(res)) {
     return failure();
   }

compiler/src/iree/compiler/Codegen/Common/Passes.h

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformOps.h"

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 17 additions & 0 deletions
@@ -155,6 +155,23 @@ def ConvertWorkgroupForallToPCFPass
   let dependentDialects = ["iree_compiler::IREE::PCF::PCFDialect"];
 }

+def ConvertForallToGenericNestWorkgroupPass
+    : InterfacePass<"iree-codegen-convert-forall-to-generic-nest-workgroup",
+                    "mlir::FunctionOpInterface"> {
+  let summary = "Converts scf.forall ops with workgroup mapping to pcf.generic";
+  let description = [{
+    Converts `scf.forall` ops with `#iree_codegen.workgroup_mapping` attributes
+    to a `pcf.generic` op using workgroup scope. The pass always linearizes
+    workgroup IDs to a single dimension.
+  }];
+  let dependentDialects = [
+    "iree_compiler::IREE::PCF::PCFDialect",
+    "arith::ArithDialect",
+    "affine::AffineDialect",
+    "scf::SCFDialect"
+  ];
+}
+
 def CombineLayoutTransformationPass :
     InterfacePass<"iree-codegen-combine-layout-transformation", "mlir::FunctionOpInterface"> {
   let summary =

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ iree_lit_test_suite(
         "convert_accgemm_to_gemm.mlir",
         "convert_bf16_arith_to_f32.mlir",
         "convert_bf16_to_uint16_buffers.mlir",
+        "convert_forall_to_generic_nest_workgroup.mlir",
         "convert_hal_descriptor_type_to_gpu_address_space.mlir",
         "convert_to_destination_passing_style.mlir",
         "convert_unsupported_float_arith.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ iree_lit_test_suite(
         "convert_accgemm_to_gemm.mlir"
         "convert_bf16_arith_to_f32.mlir"
         "convert_bf16_to_uint16_buffers.mlir"
+        "convert_forall_to_generic_nest_workgroup.mlir"
         "convert_hal_descriptor_type_to_gpu_address_space.mlir"
         "convert_to_destination_passing_style.mlir"
         "convert_unsupported_float_arith.mlir"
compiler/src/iree/compiler/Codegen/Common/test/convert_forall_to_generic_nest_workgroup.mlir

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
// RUN: iree-opt %s --pass-pipeline="builtin.module(func.func(iree-codegen-convert-forall-to-generic-nest-workgroup))" --allow-unregistered-dialect --split-input-file | FileCheck %s

// Test that workgroup scope creates 1 id/count pair with linearized scope.

// CHECK-LABEL: func.func @test_workgroup_scope
// CHECK: pcf.generic
// CHECK-SAME: scope(#iree_codegen.workgroup_scope<linearize>)
// CHECK: execute(%{{.*}})[%[[ID:.+]]: index, %[[COUNT:.+]]: index]
// Chunk size computed from total iterations / worker count.
// CHECK: %[[CHUNK:.+]] = arith.ceildivui
// Start = id * chunk_size.
// CHECK: %[[START:.+]] = arith.muli %[[ID]], %[[CHUNK]]
// End = min(start + chunk_size, total).
// CHECK: %[[END_RAW:.+]] = arith.addi %[[START]], %[[CHUNK]]
// CHECK: %[[END:.+]] = arith.minui %[[END_RAW]]
// CHECK: scf.forall (%[[IV:.+]]) = (%[[START]]) to (%[[END]])
// CHECK: "foo.body"(%[[IV]])
// CHECK: pcf.write_slice
// CHECK: pcf.return
func.func @test_workgroup_scope(%init: tensor<64xf32>) -> tensor<64xf32> {
  %result = scf.forall (%i) in (64) shared_outs(%out = %init) -> tensor<64xf32> {
    "foo.body"(%i) : (index) -> ()
    %slice = tensor.extract_slice %out[%i] [1] [1] : tensor<64xf32> to tensor<1xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %slice into %out[%i] [1] [1] : tensor<1xf32> into tensor<64xf32>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<x>]}
  return %result : tensor<64xf32>
}
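As a concrete instance of the bounds math checked above, take the 64 iterations in this test and a hypothetical workgroup count of 4: chunk = ceildiv(64, 4) = 16, so the workgroup with id 2 computes start = 2 * 16 = 32 and end = min(32 + 16, 64) = 48, i.e. its inner scf.forall covers iterations [32, 48).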

compiler/src/iree/compiler/Codegen/Dialect/GPU/ExternalInterfaces/GPUScopeExternalModels.cpp

Lines changed: 10 additions & 0 deletions
@@ -83,6 +83,11 @@ struct SubgroupScopeModel
       MLIRContext *context) const {
     return gpu::AddressSpaceAttr::get(context, gpu::AddressSpace::Workgroup);
   }
+
+  int64_t getNativeNumProcessorIds(Attribute attr) const {
+    // SubgroupScope natively provides a single 1D processor ID (subgroup_id).
+    return 1;
+  }
 };

 /// External model for LaneScopeAttr implementing ScopeAttrInterface.
@@ -133,6 +138,11 @@ struct LaneScopeModel
     // logic to allocate + subview.
     return failure();
   }
+
+  int64_t getNativeNumProcessorIds(Attribute attr) const {
+    // LaneScope natively provides a single 1D processor ID (lane_id).
+    return 1;
+  }
 };

 } // namespace
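For context, a caller-side sketch of the new hook, assuming the interface method is invoked on a ScopeAttrInterface value (the helper name numProcessorIdsFor is hypothetical and not part of this commit):

#include "iree/compiler/Codegen/Dialect/PCF/IR/PCF.h"

namespace mlir::iree_compiler {
// Hypothetical helper: the nest builder asks each scope how many processor id
// block arguments it should materialize. Subgroup and lane scopes report 1
// (subgroup_id / lane_id), so multi-dimensional foralls are handled with
// affine.linearize_index / affine.delinearize_index instead of one id per
// dimension.
static int64_t numProcessorIdsFor(IREE::PCF::ScopeAttrInterface scope) {
  return scope.getNativeNumProcessorIds();
}
} // namespace mlir::iree_compiler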
