Skip to content

Commit 95d7b17

Browse files
sakupan102pstarkcdpr
authored andcommitted
[LinalgExt][NFC] Split the op definition between pure ops and LinalgExt ops (iree-org#22368)
Split LinalgExtOp into PureOp and LinalgExtOp. This enables automatically attaching the right interfaces by simply including the generated .cpp.inc files, preventing missed interface attachments. Closes iree-org#20862 --------- Signed-off-by: Ryutaro Okada <[email protected]>
1 parent 34f36ea commit 95d7b17

File tree

10 files changed

+146
-108
lines changed

10 files changed

+146
-108
lines changed

compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,15 @@ struct LinalgExtOpInterface
331331
}
332332
};
333333

334+
template <typename... Ops>
335+
struct LinalgExtOpInterfaceHelper {
336+
static void registerOpInterface(MLIRContext *context) {
337+
(void)std::initializer_list<int>{
338+
0, (Ops::template attachInterface<LinalgExtOpInterface<Ops>>(*context),
339+
0)...};
340+
}
341+
};
342+
334343
/// Returns the buffers of the source and destination for pack and unpack ops.
335344
/// Returns a failure if the buffers can not be found.
336345
template <typename OpTy>
@@ -686,37 +695,13 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
686695
StoreToBufferOpInterface, StoreToBufferOpSubsetInterface,
687696
StoreToBufferOpSubsetInsertionInterface>(*ctx);
688697
});
689-
registry.addExtension(+[](MLIRContext *ctx,
690-
IREE::LinalgExt::IREELinalgExtDialect *dialect) {
691-
IREE::LinalgExt::ArgCompareOp::attachInterface<
692-
LinalgExtOpInterface<IREE::LinalgExt::ArgCompareOp>>(*ctx);
693-
IREE::LinalgExt::FftOp::attachInterface<
694-
LinalgExtOpInterface<IREE::LinalgExt::FftOp>>(*ctx);
695-
IREE::LinalgExt::PackOp::attachInterface<
696-
LinalgExtOpInterface<IREE::LinalgExt::PackOp>>(*ctx);
697-
IREE::LinalgExt::UnPackOp::attachInterface<
698-
LinalgExtOpInterface<IREE::LinalgExt::UnPackOp>>(*ctx);
699-
IREE::LinalgExt::ScanOp::attachInterface<
700-
LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
701-
IREE::LinalgExt::ScatterOp::attachInterface<
702-
LinalgExtOpInterface<IREE::LinalgExt::ScatterOp>>(*ctx);
703-
IREE::LinalgExt::GatherOp::attachInterface<
704-
LinalgExtOpInterface<IREE::LinalgExt::GatherOp>>(*ctx);
705-
IREE::LinalgExt::SortOp::attachInterface<
706-
LinalgExtOpInterface<IREE::LinalgExt::SortOp>>(*ctx);
707-
IREE::LinalgExt::TopkOp::attachInterface<
708-
LinalgExtOpInterface<IREE::LinalgExt::TopkOp>>(*ctx);
709-
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
710-
LinalgExtOpInterface<IREE::LinalgExt::WinogradInputTransformOp>>(*ctx);
711-
IREE::LinalgExt::WinogradFilterTransformOp::attachInterface<
712-
LinalgExtOpInterface<IREE::LinalgExt::WinogradFilterTransformOp>>(*ctx);
713-
IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
714-
LinalgExtOpInterface<IREE::LinalgExt::WinogradOutputTransformOp>>(*ctx);
715-
IREE::LinalgExt::AttentionOp::attachInterface<
716-
LinalgExtOpInterface<IREE::LinalgExt::AttentionOp>>(*ctx);
717-
IREE::LinalgExt::MapScatterOp::attachInterface<
718-
LinalgExtOpInterface<IREE::LinalgExt::MapScatterOp>>(*ctx);
719-
});
698+
registry.addExtension(
699+
+[](MLIRContext *ctx, IREE::LinalgExt::IREELinalgExtDialect *dialect) {
700+
LinalgExtOpInterfaceHelper<
701+
#define GET_OP_LIST
702+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
703+
>::registerOpInterface(ctx);
704+
});
720705
registry.insert<linalg::LinalgDialect>();
721706
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
722707
linalg::PackOp::attachInterface<PackUnPackOpInterface<linalg::PackOp>>(

compiler/src/iree/compiler/Dialect/LinalgExt/IR/BUILD.bazel

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ iree_td_library(
2020
"LinalgExtBase.td",
2121
"LinalgExtInterfaces.td",
2222
"LinalgExtOps.td",
23+
"LinalgExtPureOps.td",
2324
],
2425
include = ["*.td"],
2526
),
@@ -52,7 +53,7 @@ iree_compiler_cc_library(
5253
"LinalgExtInterfaces.cpp",
5354
"LinalgExtInterfaces.cpp.inc",
5455
"LinalgExtOps.cpp",
55-
"LinalgExtOps.cpp.inc",
56+
"LinalgExtPureOps.cpp.inc",
5657
"LinalgExtTypes.cpp.inc",
5758
"TilingInterfaceImpl.cpp",
5859
],
@@ -63,12 +64,17 @@ iree_compiler_cc_library(
6364
"LinalgExtInterfaces.h",
6465
"LinalgExtInterfaces.h.inc",
6566
"LinalgExtOps.h",
66-
"LinalgExtOps.h.inc",
67+
"LinalgExtPureOps.h.inc",
6768
"LinalgExtTypes.h.inc",
6869
],
70+
textual_hdrs = [
71+
"LinalgExtOps.cpp.inc",
72+
"LinalgExtOps.h.inc",
73+
],
6974
deps = [
7075
":LinalgExtInterfacesIncGen",
7176
":LinalgExtOpsIncGen",
77+
":LinalgExtPureOpsIncGen",
7278
":LinalgExtTypesGen",
7379
"//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
7480
"@llvm-project//llvm:Support",
@@ -117,6 +123,25 @@ iree_gentbl_cc_library(
117123
deps = [":td_files"],
118124
)
119125

126+
iree_gentbl_cc_library(
127+
name = "LinalgExtPureOpsIncGen",
128+
tbl_outs = [
129+
(
130+
["--gen-op-decls"],
131+
"LinalgExtPureOps.h.inc",
132+
),
133+
(
134+
["--gen-op-defs"],
135+
"LinalgExtPureOps.cpp.inc",
136+
),
137+
],
138+
tblgen = "@llvm-project//mlir:mlir-tblgen",
139+
td_file = "LinalgExtPureOps.td",
140+
deps = [
141+
":td_files",
142+
],
143+
)
144+
120145
iree_gentbl_cc_library(
121146
name = "LinalgExtOpsIncGen",
122147
tbl_outs = [

compiler/src/iree/compiler/Dialect/LinalgExt/IR/CMakeLists.txt

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@ iree_cc_library(
2020
"LinalgExtInterfaces.h"
2121
"LinalgExtInterfaces.h.inc"
2222
"LinalgExtOps.h"
23-
"LinalgExtOps.h.inc"
23+
"LinalgExtPureOps.h.inc"
2424
"LinalgExtTypes.h.inc"
25+
TEXTUAL_HDRS
26+
"LinalgExtOps.cpp.inc"
27+
"LinalgExtOps.h.inc"
2528
SRCS
2629
"AggregatedOpInterfaceImpl.cpp"
2730
"LinalgExtAttrs.cpp.inc"
@@ -30,12 +33,13 @@ iree_cc_library(
3033
"LinalgExtInterfaces.cpp"
3134
"LinalgExtInterfaces.cpp.inc"
3235
"LinalgExtOps.cpp"
33-
"LinalgExtOps.cpp.inc"
36+
"LinalgExtPureOps.cpp.inc"
3437
"LinalgExtTypes.cpp.inc"
3538
"TilingInterfaceImpl.cpp"
3639
DEPS
3740
::LinalgExtInterfacesIncGen
3841
::LinalgExtOpsIncGen
42+
::LinalgExtPureOpsIncGen
3943
::LinalgExtTypesGen
4044
LLVMSupport
4145
MLIRAffineDialect
@@ -75,6 +79,16 @@ iree_tablegen_library(
7579
--gen-op-interface-defs LinalgExtInterfaces.cpp.inc
7680
)
7781

82+
iree_tablegen_library(
83+
NAME
84+
LinalgExtPureOpsIncGen
85+
TD_FILE
86+
"LinalgExtPureOps.td"
87+
OUTS
88+
--gen-op-decls LinalgExtPureOps.h.inc
89+
--gen-op-defs LinalgExtPureOps.cpp.inc
90+
)
91+
7892
iree_tablegen_library(
7993
NAME
8094
LinalgExtOpsIncGen

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,12 @@ def SplitReductionMappingAttr :
8787
}];
8888
}
8989

90+
//===----------------------------------------------------------------------===//
91+
// Base Op class.
92+
//===----------------------------------------------------------------------===//
93+
94+
class IREELinalgExt_PureOp<string mnemonic, list<Trait> traits = []> :
95+
Op<IREELinalgExt_Dialect, mnemonic, traits> {
96+
}
97+
9098
#endif // IREE_DIALECT_LINALGEXT_BASE

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ void IREELinalgExtDialect::initialize() {
7272
addOperations<
7373
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
7474
>();
75+
76+
#define GET_OP_LIST
77+
addOperations<
78+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.cpp.inc"
79+
>();
7580
}
7681

7782
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.cpp.inc"

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2749,4 +2749,7 @@ DEFINE_OP_GET_EFFECTS(CustomOp)
27492749
// clang-format off
27502750
#define GET_OP_CLASSES
27512751
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc" // IWYU pragma: keep
2752+
2753+
#define GET_OP_CLASSES
2754+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.cpp.inc" // IWYU pragma: keep
27522755
// clang-format: on

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#define GET_ATTRDEF_CLASSES
2626
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtAttrs.h.inc" // IWYU pragma: export
2727

28+
#define GET_OP_CLASSES
29+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtPureOps.h.inc" // IWYU pragma: export
30+
2831
#define GET_OP_CLASSES
2932
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h.inc" // IWYU pragma: export
3033

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@ include "mlir/Interfaces/ViewLikeInterface.td"
2222
// Base class.
2323
//===----------------------------------------------------------------------===//
2424

25-
class IREELinalgExt_PureOp<string mnemonic, list<Trait> traits = []> :
26-
Op<IREELinalgExt_Dialect, mnemonic, traits> {
27-
}
28-
2925
class IREELinalgExt_Op<string mnemonic, list<Trait> traits = []> :
3026
IREELinalgExt_PureOp<mnemonic, !listconcat(traits,
3127
[AttrSizedOperandSegments,
@@ -38,50 +34,6 @@ class IREELinalgExt_Op<string mnemonic, list<Trait> traits = []> :
3834
code extraLinalgExtOpClassDeclaration = "";
3935
}
4036

41-
//===----------------------------------------------------------------------===//
42-
// Utility ops
43-
//===----------------------------------------------------------------------===//
44-
45-
def OpGroupUtilityOps : OpDocGroup {
46-
let summary = "Utility ops";
47-
let description = "";
48-
}
49-
50-
let opDocGroup = OpGroupUtilityOps in {
51-
52-
def IREELinalgExt_IndexOp : IREELinalgExt_PureOp<"index", [Pure]>,
53-
Arguments<(ins ConfinedAttr<I64Attr, [IntMinValue<0>]>:$dim)>,
54-
Results<(outs Index:$result)> {
55-
let summary = [{LinalgExt index operation.}];
56-
let description = [{
57-
This operation is a mirror of `linalg.index` operation and has the same
58-
semantics, except that `linalg.index` enforces that the parent op is a
59-
`LinalgOp`, and the `iree_linalg_ext.index` operation enforces that the
60-
parent op is a `IREE::LinalgExt::CustomOp`.
61-
}];
62-
63-
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
64-
let hasVerifier = 1;
65-
}
66-
67-
def IREELinalgExt_YieldOp : IREELinalgExt_PureOp<"yield", [Pure, ReturnLike, Terminator]> {
68-
let summary = [{LinalgExt yield op.}];
69-
let description = [{
70-
`iree_linalg_ext.yield` is a special terminator operation for blocks inside
71-
regions in `iree_linalg_ext` ops.
72-
}];
73-
74-
let arguments = (ins Variadic<AnyType>:$operands);
75-
76-
let builders = [
77-
OpBuilder<(ins), [{ /* nothing to do */ }]>,
78-
];
79-
80-
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
81-
}
82-
83-
} // OpGroupUtilityOps
84-
8537
//===----------------------------------------------------------------------===//
8638
// Non-structured ops
8739
//===----------------------------------------------------------------------===//
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
#ifndef IREE_DIALECT_LINALGEXT_PURE_OPS
8+
#define IREE_DIALECT_LINALGEXT_PURE_OPS
9+
10+
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtBase.td"
11+
include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.td"
12+
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
13+
include "mlir/Interfaces/ControlFlowInterfaces.td"
14+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
15+
include "mlir/Interfaces/InferTypeOpInterface.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
include "mlir/Interfaces/TilingInterface.td"
18+
include "mlir/Interfaces/ViewLikeInterface.td"
19+
20+
//===----------------------------------------------------------------------===//
21+
// Utility ops
22+
//===----------------------------------------------------------------------===//
23+
24+
def OpGroupUtilityOps : OpDocGroup {
25+
let summary = "Utility ops";
26+
let description = "";
27+
}
28+
29+
let opDocGroup = OpGroupUtilityOps in {
30+
31+
def IREELinalgExt_IndexOp : IREELinalgExt_PureOp<"index", [Pure]>,
32+
Arguments<(ins ConfinedAttr<I64Attr, [IntMinValue<0>]>:$dim)>,
33+
Results<(outs Index:$result)> {
34+
let summary = [{LinalgExt index operation.}];
35+
let description = [{
36+
This operation is a mirror of `linalg.index` operation and has the same
37+
semantics, except that `linalg.index` enforces that the parent op is a
38+
`LinalgOp`, and the `iree_linalg_ext.index` operation enforces that the
39+
parent op is a `IREE::LinalgExt::CustomOp`.
40+
}];
41+
42+
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
43+
let hasVerifier = 1;
44+
}
45+
46+
def IREELinalgExt_YieldOp : IREELinalgExt_PureOp<"yield", [Pure, ReturnLike, Terminator]> {
47+
let summary = [{LinalgExt yield op.}];
48+
let description = [{
49+
`iree_linalg_ext.yield` is a special terminator operation for blocks inside
50+
regions in `iree_linalg_ext` ops.
51+
}];
52+
53+
let arguments = (ins Variadic<AnyType>:$operands);
54+
55+
let builders = [
56+
OpBuilder<(ins), [{ /* nothing to do */ }]>,
57+
];
58+
59+
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
60+
}
61+
62+
} // OpGroupUtilityOps
63+
64+
#endif // IREE_DIALECT_LINALGEXT_PURE_OPS

compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -459,33 +459,12 @@ void registerUtilExternalModels(DialectRegistry &registry) {
459459
>::registerOpInterface(context);
460460
});
461461

462-
// TODO(matthias-springer): Use a helper instead of listing all ops. This is
463-
// tricky because LinalgExtOps.td includes YieldOp.
464462
registry.addExtension(+[](MLIRContext *context,
465463
IREE::LinalgExt::IREELinalgExtDialect *dialect) {
466-
IREE::LinalgExt::ScatterOp::attachInterface<
467-
LinalgOpTiedOpInterface<IREE::LinalgExt::ScatterOp>>(*context);
468-
IREE::LinalgExt::SortOp::attachInterface<
469-
LinalgOpTiedOpInterface<IREE::LinalgExt::SortOp>>(*context);
470-
IREE::LinalgExt::FftOp::attachInterface<
471-
LinalgOpTiedOpInterface<IREE::LinalgExt::FftOp>>(*context);
472-
IREE::LinalgExt::ScanOp::attachInterface<
473-
LinalgOpTiedOpInterface<IREE::LinalgExt::ScanOp>>(*context);
474-
IREE::LinalgExt::TopkOp::attachInterface<
475-
LinalgOpTiedOpInterface<IREE::LinalgExt::TopkOp>>(*context);
476-
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
477-
LinalgOpTiedOpInterface<IREE::LinalgExt::WinogradInputTransformOp>>(
478-
*context);
479-
IREE::LinalgExt::WinogradFilterTransformOp::attachInterface<
480-
LinalgOpTiedOpInterface<IREE::LinalgExt::WinogradFilterTransformOp>>(
481-
*context);
482-
IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
483-
LinalgOpTiedOpInterface<IREE::LinalgExt::WinogradOutputTransformOp>>(
484-
*context);
485-
IREE::LinalgExt::Im2colOp::attachInterface<
486-
LinalgOpTiedOpInterface<IREE::LinalgExt::Im2colOp>>(*context);
487-
IREE::LinalgExt::AttentionOp::attachInterface<
488-
LinalgOpTiedOpInterface<IREE::LinalgExt::AttentionOp>>(*context);
464+
LinalgOpTiedOpInterfaceHelper<
465+
#define GET_OP_LIST
466+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp.inc"
467+
>::registerOpInterface(context);
489468
});
490469

491470
// Hoistable Op Interface registration.

0 commit comments

Comments
 (0)