Skip to content

Commit 8c6194f

Browse files
Simplify the resolution of scf.forall created by split reductions. (iree-org#21422)
The earlier implementation of the resolution of the `scf.forall` turned out to be more involved than it needed to be. A simpler solution is to fold the `scf.forall` into the workgroup mapping loop and increase the rank of the loop appropriately. This PR implements that solution. The previous approach is kept as well, but is disabled by default. It also kicks in when there is no workgroup mapping loop for the split-reduction loop to merge with, though this probably also needs to be fixed. The reason for keeping the previous approach is that folding the split-reduction `scf.forall` and the workgroup-mapped `scf.forall` effectively "sinks" the computation that is within the split-reduction loop (but outside the workgroup-mapped loop) into the workgroup-mapped loop. This might not always be desirable, so the previous approach is kept as is for now. Signed-off-by: MaheshRavishankar <mravisha@amd.com>
1 parent 28c2301 commit 8c6194f

14 files changed

+728
-41
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ iree_compiler_cc_library(
124124
"FlattenMemRefSubspanPass.cpp",
125125
"FlattenMemRefs.cpp",
126126
"FoldAffineMinInDistributedLoops.cpp",
127+
"FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp",
127128
"FoldTensorExtractOpPass.cpp",
128129
"FoldTensorSubsetIntoVectorTransferOps.cpp",
129130
"ForOpCanonicalizationPass.cpp",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ iree_cc_library(
116116
"FlattenMemRefSubspanPass.cpp"
117117
"FlattenMemRefs.cpp"
118118
"FoldAffineMinInDistributedLoops.cpp"
119+
"FoldSplitReductionAndWorkgroupMappingLoopsPass.cpp"
119120
"FoldTensorExtractOpPass.cpp"
120121
"FoldTensorSubsetIntoVectorTransferOps.cpp"
121122
"ForOpCanonicalizationPass.cpp"
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Transforms/Transforms.h"
9+
#include "mlir/Dialect/SCF/IR/SCF.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
namespace mlir::iree_compiler {
15+
16+
#define GEN_PASS_DEF_FOLDSPLITREDUCTIONANDWORKGROUPMAPPINGLOOPSPASS
17+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
18+
19+
namespace {
20+
21+
struct FoldSplitReductionAndWorkgroupMappingLoopsPass
22+
: public impl::FoldSplitReductionAndWorkgroupMappingLoopsPassBase<
23+
FoldSplitReductionAndWorkgroupMappingLoopsPass> {
24+
using Base::Base;
25+
26+
void runOnOperation() override;
27+
};
28+
29+
void FoldSplitReductionAndWorkgroupMappingLoopsPass::runOnOperation() {
30+
MLIRContext *context = &getContext();
31+
Operation *op = getOperation();
32+
33+
RewritePatternSet patterns(context);
34+
populateFoldSplitReductionAndWorkgroupMappingLoops(patterns);
35+
if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
36+
op->emitOpError("failed to apply pattern to fold split reduction loop with "
37+
"workgroup for all");
38+
return signalPassFailure();
39+
}
40+
}
41+
42+
} // namespace
43+
44+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,11 @@ def ReconcileTranslationInfoPass
286286
clEnumValN(IREE::Codegen::WorkgroupId::IdY, "y",
287287
"Constrain the workgroup distribution to use only workgroups along x and y."),
288288
clEnumValN(IREE::Codegen::WorkgroupId::IdZ, "z",
289-
"Constrain the workgroup distribution to use only workgroups along x, y and z."))}]>
289+
"Constrain the workgroup distribution to use only workgroups along x, y and z."))}]>,
290+
Option<"foldSplitReductionLoopIntoWorkgroupMappingLoop",
291+
"fold-split-reduction-loop-into-workgroup-mapping-loop",
292+
"bool", /*default=*/"true",
293+
"Resolve scf.forall loops created by split reduction by folding into workgroup mapping loop">
290294
];
291295
}
292296

@@ -382,6 +386,11 @@ def FoldReshapeIntoInterfaceTensorPass :
382386
let summary = "Folds reshape operations into the interface bindings.";
383387
}
384388

389+
def FoldSplitReductionAndWorkgroupMappingLoopsPass :
390+
Pass<"iree-codegen-fold-split-reduction-and-workgroup-mapping-loops", ""> {
391+
let summary = "Folds `scf.forall` loops created by split reduction and workgroup mapping.";
392+
}
393+
385394
def FoldTensorExtractOpPass :
386395
Pass<"iree-codegen-fold-tensor-extract-op", ""> {
387396
let summary = "Fold `tensor.extract` operations prior to lowering to LLVM";

compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/Analysis/CallGraph.h"
2323
#include "mlir/Dialect/Affine/Utils.h"
2424
#include "mlir/Dialect/Arith/Utils/Utils.h"
25+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2526

2627
namespace mlir::iree_compiler {
2728

@@ -470,9 +471,12 @@ resolveSplitReduceForAll(RewriterBase &rewriter, FunctionOpInterface funcOp,
470471
SmallVector<scf::ForallOp> splitReductionForAllOps;
471472
funcOp.walk([&splitReductionForAllOps](scf::ForallOp forAllOp) {
472473
auto mapping = forAllOp.getMapping();
473-
if (!mapping || mapping->size() != 1 ||
474-
!isa<IREE::LinalgExt::SplitReductionMappingAttr>(
475-
mapping->getValue().front())) {
474+
if (!mapping) {
475+
return;
476+
}
477+
if (llvm::none_of(
478+
mapping->getValue(),
479+
llvm::IsaPred<IREE::LinalgExt::SplitReductionMappingAttr>)) {
476480
return;
477481
}
478482
splitReductionForAllOps.push_back(forAllOp);
@@ -619,6 +623,18 @@ getTargetFuncAttrs(IREE::Codegen::TranslationInfoAttr translationInfo) {
619623
void ReconcileTranslationInfoPass::runOnOperation() {
620624
auto variantOp = getOperation();
621625
auto innerModuleOp = variantOp.getInnerModule();
626+
MLIRContext *context = &getContext();
627+
628+
if (foldSplitReductionLoopIntoWorkgroupMappingLoop) {
629+
RewritePatternSet foldLoopPattern(context);
630+
populateFoldSplitReductionAndWorkgroupMappingLoops(foldLoopPattern);
631+
if (failed(
632+
applyPatternsGreedily(innerModuleOp, std::move(foldLoopPattern)))) {
633+
innerModuleOp.emitOpError(
634+
"failed to fold split-reduction loop and workgroup mapping loop");
635+
return signalPassFailure();
636+
}
637+
}
622638

623639
// Get the symbol table of the inner module to lookup exported functions.
624640
SymbolTable symbolTable(innerModuleOp);
@@ -638,14 +654,14 @@ void ReconcileTranslationInfoPass::runOnOperation() {
638654
// Skip external functions.
639655
continue;
640656
}
657+
641658
// Resolve workgroup distribution related `scf.forall` ops.
642659
if (failed(resolveWorkgroupForAll(rewriter, rootFuncOp, distributeAlong))) {
643660
variantOp.emitOpError(
644661
"failed to resolve workgroup distribution forall ops");
645662
return signalPassFailure();
646663
}
647664

648-
// Resolve split reduction distribution.
649665
if (failed(
650666
resolveSplitReduceForAll(rewriter, rootFuncOp, distributeAlong))) {
651667
variantOp.emitOpError("failed to resolve split reduction forall ops");

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ iree_lit_test_suite(
5555
"fold_affine_min_in_distributed_loops.mlir",
5656
"fold_affine_min_of_block_id.mlir",
5757
"fold_reshape_into_interface_tensor.mlir",
58+
"fold_split_reduction_workgroup_mapping_loops.mlir",
5859
"fold_tensor_extract_op.mlir",
5960
"forop_canonicalization.mlir",
6061
"generic_vectorization.mlir",
@@ -92,7 +93,9 @@ iree_lit_test_suite(
9293
"propagate_dispatch_size_bounds.mlir",
9394
"propagate_reshapes_by_expansion.mlir",
9495
"reconcile_translation_info.mlir",
96+
"reconcile_translation_info_legacy_resolve_split_reduction.mlir",
9597
"reconcile_translation_info_linearize.mlir",
98+
"reconcile_translation_info_linearize_legacy_resolve_split_reduction.mlir",
9699
"reductions.mlir",
97100
"rematerialize_parallel_ops.mlir",
98101
"remove_dead_allocs.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ iree_lit_test_suite(
5050
"fold_affine_min_in_distributed_loops.mlir"
5151
"fold_affine_min_of_block_id.mlir"
5252
"fold_reshape_into_interface_tensor.mlir"
53+
"fold_split_reduction_workgroup_mapping_loops.mlir"
5354
"fold_tensor_extract_op.mlir"
5455
"forall_to_for.mlir"
5556
"forop_canonicalization.mlir"
@@ -88,7 +89,9 @@ iree_lit_test_suite(
8889
"propagate_dispatch_size_bounds.mlir"
8990
"propagate_reshapes_by_expansion.mlir"
9091
"reconcile_translation_info.mlir"
92+
"reconcile_translation_info_legacy_resolve_split_reduction.mlir"
9193
"reconcile_translation_info_linearize.mlir"
94+
"reconcile_translation_info_linearize_legacy_resolve_split_reduction.mlir"
9295
"reductions.mlir"
9396
"rematerialize_parallel_ops.mlir"
9497
"remove_dead_allocs.mlir"
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: iree-opt --iree-codegen-fold-split-reduction-and-workgroup-mapping-loops --split-input-file --mlir-print-local-scope --allow-unregistered-dialect %s | FileCheck %s
2+
3+
func.func @simple_example_1dmapping(%0 : index, %1 : index, %2 : index, %3 : index,
4+
%4 : index, %5 : index) {
5+
scf.forall (%arg0) = (%0) to (%1) step (%2) {
6+
"use1"(%arg0) : (index) -> ()
7+
scf.forall (%arg1) = (%3) to (%4) step (%5) {
8+
"use2"(%arg0, %arg1) : (index, index) -> ()
9+
} {mapping = [#iree_codegen.workgroup_mapping<x>]}
10+
} {mapping = [#iree_linalg_ext.split_reduction_mapping]}
11+
return
12+
}
13+
// CHECK: func @simple_example_1dmapping
14+
// CHECK-SAME: %[[SPLIT_LB:[a-zA-Z0-9_]+]]: index
15+
// CHECK-SAME: %[[SPLIT_UB:[a-zA-Z0-9_]+]]: index
16+
// CHECK-SAME: %[[SPLIT_STEP:[a-zA-Z0-9_]+]]: index
17+
// CHECK-SAME: %[[WG_LB:[a-zA-Z0-9_]+]]: index
18+
// CHECK-SAME: %[[WG_UB:[a-zA-Z0-9_]+]]: index
19+
// CHECK-SAME: %[[WG_STEP:[a-zA-Z0-9_]+]]: index
20+
// CHECK: scf.forall
21+
// CHECK-SAME: %[[IV0:[a-zA-Z0-9]+]]
22+
// CHECK-SAME: %[[IV1:[a-zA-Z0-9]+]]
23+
// CHECK-SAME: = (%[[SPLIT_LB]], %[[WG_LB]])
24+
// CHECK-SAME: to (%[[SPLIT_UB]], %[[WG_UB]])
25+
// CHECK-SAME: step (%[[SPLIT_STEP]], %[[WG_STEP]])
26+
// CHECK: "use1"(%[[IV0]])
27+
// CHECK: "use2"(%[[IV0]], %[[IV1]])
28+
// CHECK: mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]
29+
30+
// -----
31+
32+
func.func @simple_example_2dmapping(%0 : index, %1 : index, %2 : index, %3 : index,
33+
%4 : index) {
34+
scf.forall (%arg0) = (%0) to (%1) step (%2) {
35+
"use1"(%arg0) : (index) -> ()
36+
scf.forall (%arg1, %arg2) in (%3, %4) {
37+
"use2"(%arg0, %arg1, %arg2) : (index, index, index) -> ()
38+
} {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
39+
} {mapping = [#iree_linalg_ext.split_reduction_mapping]}
40+
return
41+
}
42+
// CHECK: func @simple_example_2dmapping
43+
// CHECK-SAME: %[[SPLIT_LB:[a-zA-Z0-9_]+]]: index
44+
// CHECK-SAME: %[[SPLIT_UB:[a-zA-Z0-9_]+]]: index
45+
// CHECK-SAME: %[[SPLIT_STEP:[a-zA-Z0-9_]+]]: index
46+
// CHECK-SAME: %[[WG_UB0:[a-zA-Z0-9_]+]]: index
47+
// CHECK-SAME: %[[WG_UB1:[a-zA-Z0-9_]+]]: index
48+
// CHECK: scf.forall
49+
// CHECK-SAME: %[[IV0:[a-zA-Z0-9]+]]
50+
// CHECK-SAME: %[[IV1:[a-zA-Z0-9]+]]
51+
// CHECK-SAME: %[[IV2:[a-zA-Z0-9]+]]
52+
// CHECK-SAME: = (%[[SPLIT_LB]], 0, 0)
53+
// CHECK-SAME: to (%[[SPLIT_UB]], %[[WG_UB0]], %[[WG_UB1]])
54+
// CHECK-SAME: step (%[[SPLIT_STEP]], 1, 1)
55+
// CHECK: "use1"(%[[IV0]])
56+
// CHECK: "use2"(%[[IV0]], %[[IV1]], %[[IV2]])
57+
// CHECK: mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]
58+
59+
// -----
60+
61+
func.func @simple_example_3dmapping(%0 : index, %1 : index, %2 : index, %3 : index,
62+
%4 : index, %5 : index) {
63+
scf.forall (%arg0) = (%0) to (%1) step (%2) {
64+
"use1"(%arg0) : (index) -> ()
65+
scf.forall (%arg1, %arg2, %arg3) in (%3, %4, %5) {
66+
"use2"(%arg1, %arg2, %arg3) : (index, index, index) -> ()
67+
} {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
68+
} {mapping = [#iree_linalg_ext.split_reduction_mapping]}
69+
return
70+
}
71+
// CHECK: func @simple_example_3dmapping
72+
// CHECK-SAME: %[[SPLIT_LB:[a-zA-Z0-9_]+]]: index
73+
// CHECK-SAME: %[[SPLIT_UB:[a-zA-Z0-9_]+]]: index
74+
// CHECK-SAME: %[[SPLIT_STEP:[a-zA-Z0-9_]+]]: index
75+
// CHECK-SAME: %[[ORIG_UB0:[a-zA-Z0-9_]+]]: index
76+
// CHECK-SAME: %[[ORIG_UB1:[a-zA-Z0-9_]+]]: index
77+
// CHECK-SAME: %[[ORIG_UB2:[a-zA-Z0-9_]+]]: index
78+
// CHECK: scf.forall
79+
// CHECK-SAME: %[[IV0:[a-zA-Z0-9]+]]
80+
// CHECK-SAME: %[[IV1:[a-zA-Z0-9]+]]
81+
// CHECK-SAME: %[[IV2:[a-zA-Z0-9]+]]
82+
// CHECK-SAME: %[[IV3:[a-zA-Z0-9]+]]
83+
// CHECK-SAME: = (%[[SPLIT_LB]], 0, 0, 0)
84+
// CHECK-SAME: to (%[[SPLIT_UB]], %[[ORIG_UB0]], %[[ORIG_UB1]], %[[ORIG_UB2]])
85+
// CHECK-SAME: step (%[[SPLIT_STEP]], 1, 1, 1)
86+
// CHECK: "use1"(%[[IV0]])
87+
// CHECK: "use2"(%[[IV1]], %[[IV2]], %[[IV3]])
88+
// CHECK: mapping = [#iree_codegen.workgroup_mapping<z:1>, #iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]

compiler/src/iree/compiler/Codegen/Common/test/reconcile_translation_info.mlir

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -722,10 +722,9 @@ hal.executable private @split_reduction_executable {
722722
}
723723
}
724724
}
725-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((-s3 + s4) ceildiv s5) * ((s1 * s2) * s0))>
725+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2, s3, s4, s5] -> (((s1 * s2) * s0) * ((-s3 + s4) ceildiv s5))>
726726
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
727-
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0 floordiv ((-s1 + s2) ceildiv s3))>
728-
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
727+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
729728
// CHECK: @split_reduction_variant
730729
// CHECK: hal.executable.export
731730
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
@@ -746,11 +745,62 @@ hal.executable private @split_reduction_executable {
746745
// CHECK-DAG: %[[ORIG_UB2:.+]] = hal.interface.constant.load {{.+}} ordinal(5)
747746
// CHECK-DAG: %[[SPLIT_NPROCS:.+]] = affine.apply #[[MAP1]]()[%[[SPLIT_LB]], %[[SPLIT_UB]], %[[SPLIT_STEP]]]
748747
// CHECK-DAG: %[[IDX:.+]] = hal.interface.workgroup.id[0]
748+
// CHECK: %[[DELINEARIZE:.+]]:4 = affine.delinearize_index %[[IDX]] into (%[[SPLIT_NPROCS]], %[[ORIG_UB0]], %[[ORIG_UB1]], %[[ORIG_UB2]])
749+
// CHECK: %[[SPLITIVREPLACEMENT:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#0, %[[SPLIT_STEP]]]
750+
// CHECK: "use1"(%[[SPLITIVREPLACEMENT]])
751+
// CHECK: "use2"(%[[DELINEARIZE]]#1, %[[DELINEARIZE]]#2, %[[DELINEARIZE]]#3)
752+
753+
// -----
754+
755+
// Check that a split-reduction loop still gets resolved when there is no workgroup mapping loop to fold it into.
756+
757+
#pipeline_layout = #hal.pipeline.layout<constants = 3, bindings = [
758+
#hal.pipeline.binding<storage_buffer, "ReadOnly">,
759+
#hal.pipeline.binding<storage_buffer>]>
760+
hal.executable private @only_split_reduction_executable {
761+
hal.executable.variant public @only_split_reduction_variant target(#hal.executable.target<"", "", {}>) {
762+
hal.executable.export public @only_split_reduction layout(#pipeline_layout) count(
763+
%arg0: !hal.device, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index) {
764+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice %arg1, %arg2, %arg3
765+
%return_x, %return_y, %return_z =
766+
iree_tensor_ext.dispatch.workgroup_count_split_reduction_modifier(%x, %y, %z), %arg1, %arg2, %arg3
767+
hal.return %return_x, %return_y, %return_z : index, index, index
768+
}
769+
builtin.module {
770+
func.func @only_split_reduction() {
771+
%cst0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
772+
%cst1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
773+
%cst2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
774+
%0 = iree_tensor_ext.dispatch.workload.ordinal %cst0, 0 : index
775+
%1 = iree_tensor_ext.dispatch.workload.ordinal %cst1, 1 : index
776+
%2 = iree_tensor_ext.dispatch.workload.ordinal %cst2, 2 : index
777+
scf.forall (%arg0) = (%0) to (%1) step (%2) {
778+
"use1"(%arg0) : (index) -> ()
779+
} {mapping = [#iree_linalg_ext.split_reduction_mapping]}
780+
return
781+
}
782+
}
783+
}
784+
}
785+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((-s0 + s1) ceildiv s2)>
786+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2, s3] -> (s0 floordiv ((-s1 + s2) ceildiv s3))>
787+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
788+
// CHECK: @only_split_reduction_variant
789+
// CHECK: hal.executable.export
790+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
791+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
792+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
793+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
794+
// CHECK-DAG: %[[NUMWORKGROUPSX:.+]] = affine.apply #[[MAP0]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
795+
// CHECK: hal.return %[[NUMWORKGROUPSX]], %[[C1]], %[[C1]]
796+
// CHECK: func @only_split_reduction
797+
// CHECK-DAG: %[[SPLIT_LB:.+]] = hal.interface.constant.load {{.+}} ordinal(0)
798+
// CHECK-DAG: %[[SPLIT_UB:.+]] = hal.interface.constant.load {{.+}} ordinal(1)
799+
// CHECK-DAG: %[[SPLIT_STEP:.+]] = hal.interface.constant.load {{.+}} ordinal(2)
800+
// CHECK-DAG: %[[SPLIT_NPROCS:.+]] = affine.apply #[[MAP0]]()[%[[SPLIT_LB]], %[[SPLIT_UB]], %[[SPLIT_STEP]]]
801+
// CHECK-DAG: %[[IDX:.+]] = hal.interface.workgroup.id[0]
749802
// CHECK-DAG: %[[COUNTX:.+]] = hal.interface.workgroup.count[0]
750-
// CHECK-DAG: %[[ORIG_COUNTZ:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]], %[[SPLIT_LB]], %[[SPLIT_UB]], %[[SPLIT_STEP]]]
751-
// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IDX]] into (%[[SPLIT_NPROCS]], %[[ORIG_COUNTZ]]
752-
// CHECK: %[[SPLITIVREPLACEMENT:.+]] = affine.apply #[[MAP3]]()[%[[DELINEARIZE]]#0, %[[SPLIT_STEP]]]
803+
// CHECK-DAG: %[[ORIGCOUNTX:.+]] = affine.apply #[[MAP1]]()[%[[COUNTX]], %[[SPLIT_LB]], %[[SPLIT_UB]], %[[SPLIT_STEP]]]
804+
// CHECK: %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IDX]] into (%[[SPLIT_NPROCS]], %[[ORIGCOUNTX]])
805+
// CHECK: %[[SPLITIVREPLACEMENT:.+]] = affine.apply #[[MAP2]]()[%[[DELINEARIZE]]#0, %[[SPLIT_STEP]]]
753806
// CHECK: "use1"(%[[SPLITIVREPLACEMENT]])
754-
// CHECK: %[[OTHERIVREPLACEMENTS:.+]]:3 = affine.delinearize_index %[[DELINEARIZE]]#1
755-
// CHECK-SAME: into (%[[ORIG_UB0]], %[[ORIG_UB1]], %[[ORIG_UB2]]
756-
// CHECK: "use2"(%[[OTHERIVREPLACEMENTS]]#0, %[[OTHERIVREPLACEMENTS]]#1, %[[OTHERIVREPLACEMENTS]]#2)

0 commit comments

Comments
 (0)