Skip to content

Commit ea4f95e

Browse files
Groverksskeshavvinayak01
authored andcommitted
[DispatchCreation] Drop unit dims for flow.parameter.named (iree-org#21687)
The input shouldn't have stream.parameter.named attributes at input level after iree-org#17303 , this patch makes sure we handle both attributes similarily in DispatchCreation, so that we see no regressions when frontends switch to this. Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 4cb802b commit ea4f95e

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "iree/compiler/Dialect/Flow/IR/FlowTypes.h"
1415
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
1516
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
1617
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
@@ -252,11 +253,19 @@ foldUnitDimsOnGlobal(IRRewriter &rewriter, IREE::Util::GlobalOpInterface global,
252253
newGlobalType);
253254
})
254255
.Case<IREE::Stream::NamedParameterAttr>(
256+
// TODO: Remove this case once frontends have caught up, we should
257+
// not have stream.parameter.named at this level.
255258
[&](IREE::Stream::NamedParameterAttr attr) {
256259
return IREE::Stream::NamedParameterAttr::get(
257260
rewriter.getContext(), newGlobalType, attr.getScope(),
258261
attr.getKey(), attr.getConfig());
259262
})
263+
.Case<IREE::Flow::NamedParameterAttr>(
264+
[&](IREE::Flow::NamedParameterAttr attr) {
265+
return IREE::Flow::NamedParameterAttr::get(
266+
rewriter.getContext(), newGlobalType, attr.getScope(),
267+
attr.getKey(), attr.getConfig());
268+
})
260269
.Default([&](Attribute) { return nullptr; });
261270
if (!newInitialValue) {
262271
return success();

compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,21 @@ module @fold_stream_parameter {
109109

110110
// -----
111111

112+
module @fold_flow_parameter {
113+
util.global private mutable @global = #flow.parameter.named<"module"::"global"> : tensor<1x1x10xf32>
114+
util.func public @fold_flow_parameter() -> tensor<1x1x10xf32> {
115+
%global = util.global.load @global : tensor<1x1x10xf32>
116+
util.return %global : tensor<1x1x10xf32>
117+
}
118+
}
119+
120+
// CHECK: module @fold_flow_parameter
121+
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #flow.parameter.named<"module"::"global"> : tensor<10xf32>
122+
// CHECK: util.func public @fold_flow_parameter
123+
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>
124+
125+
// -----
126+
112127
util.func public @scatter(%arg0 : tensor<4xi64>, %arg1 : tensor<4x1xi32>, %arg2 : tensor<4xi64>) -> tensor<4xi64> {
113128
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%arg0, %arg1: tensor<4xi64>, tensor<4x1xi32>) outs(%arg2 : tensor<4xi64>) {
114129
^bb0(%arg3: i64, %arg4: i64):

0 commit comments

Comments
 (0)