Skip to content

Commit bd524f3

Browse files
authored
[Encoding] Implement isIdentityLayout interface method. (iree-org#20286)
The revision introduces `isIdentityLayout` interface method to SerializableEncodingAttrInterface, and implements the method for `EncodingAttr` and `PadEncodingLayoutAttr` attributes. It also fixes a bug in `EncodingAttr` builders, where we should use NULL when maps are not present. --------- Signed-off-by: hanhanW <[email protected]>
1 parent e122dcc commit bd524f3

File tree

6 files changed

+155
-4
lines changed

6 files changed

+155
-4
lines changed

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,14 @@ EncodingAttr EncodingAttr::get(MLIRContext *ctx, int64_t operandIndex,
6969
ArrayRef<Attribute> layouts) {
7070
Builder b(ctx);
7171
auto opTypeAttr = EncodingOpTypeAttr::get(ctx, opType);
72+
auto mapsAttr = maps.empty() ? ArrayAttr() : b.getAffineMapArrayAttr(maps);
7273
auto roundDimsToAttr = roundDimsTo.empty()
7374
? DenseI64ArrayAttr()
7475
: b.getDenseI64ArrayAttr(roundDimsTo);
7576
auto layoutsAttr = layouts.empty() ? ArrayAttr() : b.getArrayAttr(layouts);
7677
return get(ctx, b.getIndexAttr(operandIndex), opTypeAttr,
77-
b.getTypeArrayAttr(elemTypes), b.getAffineMapArrayAttr(maps),
78-
roundDimsToAttr, layoutsAttr);
78+
b.getTypeArrayAttr(elemTypes), mapsAttr, roundDimsToAttr,
79+
layoutsAttr);
7980
}
8081

8182
LogicalResult
@@ -193,6 +194,25 @@ EncodingAttr::cloneWithNewOperandIndexingMap(AffineMap newIndexingMap) {
193194

194195
bool EncodingAttr::isSerialized() const { return getLayouts() ? true : false; }
195196

197+
bool EncodingAttr::isIdentityLayout() const {
198+
if (!isSerialized()) {
199+
return false;
200+
}
201+
ArrayAttr layoutsAttr = getLayouts();
202+
if (!llvm::all_of(layoutsAttr.getValue(),
203+
llvm::IsaPred<SerializableEncodingAttrInterface>)) {
204+
return false;
205+
}
206+
return llvm::all_of(layoutsAttr.getValue(), [](Attribute attr) {
207+
auto serializableAttr =
208+
llvm::dyn_cast<SerializableEncodingAttrInterface>(attr);
209+
if (!serializableAttr) {
210+
return false;
211+
}
212+
return serializableAttr.isIdentityLayout();
213+
});
214+
}
215+
196216
Attribute EncodingAttr::cloneWithLayouts(ArrayRef<Attribute> layouts) const {
197217
MLIRContext *ctx = getContext();
198218
return get(ctx, getOperandIndex(), getOpType(), getElementTypes(),
@@ -338,6 +358,11 @@ PadEncodingLayoutAttr PadEncodingLayoutAttr::getIdentityAttr(MLIRContext *ctx,
338358
return get(ctx, zeros);
339359
}
340360

361+
bool PadEncodingLayoutAttr::isIdentityLayout() const {
362+
ArrayRef<int32_t> padding = getPadding().asArrayRef();
363+
return llvm::all_of(padding, [](int32_t val) { return val == 0; });
364+
}
365+
341366
Value PadEncodingLayoutAttr::calculateStorageSizeInBytes(
342367
Location loc, OpBuilder &builder, RankedTensorType type,
343368
ValueRange dynamicDims) const {

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def EncodingAttr :
5656
IREEEncoding_Attr<"Encoding", [
5757
DeclareAttrInterfaceMethods<IREEEncoding_SerializableEncodingAttrInterface, [
5858
"isSerialized",
59+
"isIdentityLayout",
5960
"cloneWithLayouts",
6061
"calculateStorageSizeInBytes",
6162
]>
@@ -149,8 +150,10 @@ def EncodingAttr :
149150
//===---------------------------------------------------------------------===//
150151

151152
def PadEncodingLayoutAttr : IREEEncoding_Attr<"PadEncodingLayout", [
152-
DeclareAttrInterfaceMethods<IREEEncoding_SerializableEncodingAttrInterface,
153-
["calculateStorageSizeInBytes"]>
153+
DeclareAttrInterfaceMethods<IREEEncoding_SerializableEncodingAttrInterface, [
154+
"calculateStorageSizeInBytes",
155+
"isIdentityLayout",
156+
]>
154157
]> {
155158
let mnemonic = "pad_encoding_layout";
156159
let assemblyFormat = "`<` $padding `>`";

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingInterfaces.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,19 @@ def IREEEncoding_SerializableEncodingAttrInterface :
121121
return true;
122122
}]
123123
>,
124+
InterfaceMethod<
125+
/*desc=*/[{
126+
Returns true if the encoding is known as an identity layout.
127+
}],
128+
/*retTy=*/"bool",
129+
/*methodName=*/"isIdentityLayout",
130+
/*args=*/(ins
131+
),
132+
/*methodBody=*/"",
133+
/*defaultImplementation=*/[{
134+
return false;
135+
}]
136+
>,
124137
InterfaceMethod<
125138
/*desc=*/[{
126139
Creates an encoding with a new layout list. It is valid to drop any
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2025 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
load("//build_tools/bazel:build_defs.oss.bzl", "iree_compiler_cc_test")
8+
9+
package(
10+
features = ["layering_check"],
11+
licenses = ["notice"], # Apache 2.0
12+
)
13+
14+
iree_compiler_cc_test(
15+
name = "EncodingAttrTest",
16+
testonly = True,
17+
srcs = ["EncodingAttrTest.cpp"],
18+
deps = [
19+
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
20+
"//compiler/src/iree/testing:gtest_main",
21+
"@com_google_googletest//:gtest",
22+
"@llvm-project//mlir:DialectUtils",
23+
"@llvm-project//mlir:IR",
24+
],
25+
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
################################################################################
2+
# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from #
3+
# compiler/src/iree/compiler/Dialect/Encoding/IR/unittests/BUILD.bazel #
4+
# #
5+
# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary #
6+
# CMake-only content. #
7+
# #
8+
# To disable autogeneration for this file entirely, delete this header. #
9+
################################################################################
10+
11+
iree_add_all_subdirs()
12+
13+
iree_cc_test(
14+
NAME
15+
EncodingAttrTest
16+
SRCS
17+
"EncodingAttrTest.cpp"
18+
DEPS
19+
MLIRIR
20+
gmock
21+
gtest
22+
iree::compiler::Dialect::Encoding::IR
23+
iree::testing::gtest_main
24+
)
25+
26+
### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#include <gmock/gmock.h>
8+
#include <gtest/gtest.h>
9+
10+
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
11+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
12+
13+
namespace mlir::iree_compiler::IREE::Encoding {
14+
namespace {
15+
16+
class EncodingAttrsTest : public ::testing::Test {
17+
protected:
18+
EncodingAttrsTest() {
19+
reg.insert<IREEEncodingDialect>();
20+
ctx.appendDialectRegistry(reg);
21+
ctx.loadAllAvailableDialects();
22+
}
23+
~EncodingAttrsTest() override {}
24+
25+
MLIRContext *getContext() { return &ctx; }
26+
27+
private:
28+
MLIRContext ctx;
29+
DialectRegistry reg;
30+
};
31+
32+
TEST_F(EncodingAttrsTest, EncodingAttr) {
33+
MLIRContext *ctx = getContext();
34+
Builder builder(ctx);
35+
SmallVector<Type> elemTypes(3, builder.getF32Type());
36+
auto attr = cast<SerializableEncodingAttrInterface>(EncodingAttr::get(
37+
ctx, /*operandIndex=*/0, EncodingOpType::matmul, elemTypes));
38+
EXPECT_FALSE(attr.isIdentityLayout());
39+
40+
attr = cast<SerializableEncodingAttrInterface>(attr.cloneWithLayouts(
41+
PadEncodingLayoutAttr::getIdentityAttr(ctx, /*rank=*/2)));
42+
EXPECT_TRUE(attr.isIdentityLayout());
43+
}
44+
45+
TEST_F(EncodingAttrsTest, PadEncodingLayoutAttr) {
46+
MLIRContext *ctx = getContext();
47+
auto zeroPaddingAttr =
48+
PadEncodingLayoutAttr::getIdentityAttr(ctx, /*rank=*/2);
49+
EXPECT_TRUE(cast<SerializableEncodingAttrInterface>(zeroPaddingAttr)
50+
.isIdentityLayout());
51+
52+
SmallVector<int32_t> paddings = {4, 2};
53+
auto nonZeroPaddingAttr = PadEncodingLayoutAttr::get(ctx, paddings);
54+
EXPECT_FALSE(cast<SerializableEncodingAttrInterface>(nonZeroPaddingAttr)
55+
.isIdentityLayout());
56+
}
57+
58+
} // namespace
59+
} // namespace mlir::iree_compiler::IREE::Encoding

0 commit comments

Comments
 (0)