Skip to content

Commit b4ba1e8

Browse files
CopilothanhanW
andauthored
[LinalgExt][NFC] Move AttrSizedOperandSegments from base class to individual ops (iree-org#22430)
The `AttrSizedOperandSegments` trait was inherited by all ops through `IREELinalgExt_Op`, but only ops with multiple variadic operands need it. This PR makes the trait explicit per-op for clarity and correctness. ## Changes **Base class (`IREELinalgExt_Op`)** - Removed `AttrSizedOperandSegments` trait **Ops with trait added (15 ops)** - Added `AttrSizedOperandSegments` to: ScatterOp, GatherOp, SortOp, FftOp, ScanOp, TopkOp, ArgCompareOp, ExpReductionOp, Im2colOp, PackOp, UnPackOp, WinogradInputTransformOp, WinogradFilterTransformOp, WinogradOutputTransformOp, CustomOp - All have multiple variadic operands (e.g., `Variadic<...>:$inputs, Variadic<...>:$outputs`) **Ops migrated from `IREELinalgExt_PureOp` to `IREELinalgExt_Op` (3 ops)** - MapScatterOp: No trait added (only 2 regular operands) - AttentionOp: No trait added (only has optional operand, not multiple variadic operands) - OnlineAttentionOp: No trait added (only has optional operand, not multiple variadic operands) - Removed duplicate trait declarations now inherited from base class **Note**: `AttrSizedOperandSegments` is only required for ops with multiple variadic operands or variadic operands in non-terminal positions. Optional operands alone do not require this trait. ```tablegen // Before class IREELinalgExt_Op<...> : IREELinalgExt_PureOp<..., [AttrSizedOperandSegments, ...]> // After class IREELinalgExt_Op<...> : IREELinalgExt_PureOp<..., [...]> def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter", [AttrSizedOperandSegments, ...]> ``` Fixes iree-org#22429 Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: hanhanW <[email protected]>
1 parent 34379e6 commit b4ba1e8

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ include "mlir/Interfaces/ViewLikeInterface.td"
2424

2525
class IREELinalgExt_Op<string mnemonic, list<Trait> traits = []> :
2626
IREELinalgExt_PureOp<mnemonic, !listconcat(traits,
27-
[AttrSizedOperandSegments,
28-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
27+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
2928
DestinationStyleOpInterface, LinalgExtInterface,
3029
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">
3130
])> {
@@ -46,7 +45,8 @@ def OpGroupNonStructuredOps : OpDocGroup {
4645
let opDocGroup = OpGroupNonStructuredOps in {
4746

4847
def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
49-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
48+
[AttrSizedOperandSegments,
49+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
5050
DeclareOpInterfaceMethods<LinalgFusionInterface,
5151
["getIndexingMapsForResults", "getIndexingMapsForOperands",
5252
"getStaticLoopRanges"]>,
@@ -177,7 +177,8 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
177177
}
178178

179179
def IREELinalgExt_GatherOp : IREELinalgExt_Op<"gather",
180-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
180+
[AttrSizedOperandSegments,
181+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
181182
DeclareOpInterfaceMethods<LinalgFusionInterface,
182183
["getIndexingMapsForResults", "getIndexingMapsForOperands",
183184
"getStaticLoopRanges"]>,
@@ -271,10 +272,8 @@ def IREELinalgExt_GatherOp : IREELinalgExt_Op<"gather",
271272
let hasCanonicalizer = 1;
272273
}
273274

274-
def IREELinalgExt_MapScatterOp : IREELinalgExt_PureOp<"map_scatter",
275-
[LinalgExtInterface, DestinationStyleOpInterface,
276-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
277-
DeclareOpInterfaceMethods<TilingInterface,
275+
def IREELinalgExt_MapScatterOp : IREELinalgExt_Op<"map_scatter",
276+
[DeclareOpInterfaceMethods<TilingInterface,
278277
["getIterationDomain",
279278
"getLoopIteratorTypes",
280279
"getResultTilePosition",
@@ -375,7 +374,8 @@ def IREELinalgExt_MapScatterOp : IREELinalgExt_PureOp<"map_scatter",
375374
}
376375

377376
def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort",
378-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
377+
[AttrSizedOperandSegments,
378+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
379379
DeclareOpInterfaceMethods<TilingInterface,
380380
["generateScalarImplementation",
381381
"getIterationDomain",
@@ -428,6 +428,7 @@ def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort",
428428
}
429429

430430
def IREELinalgExt_FftOp : IREELinalgExt_Op<"fft", [
431+
AttrSizedOperandSegments,
431432
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
432433
DeclareOpInterfaceMethods<TilingInterface,
433434
["generateScalarImplementation",
@@ -498,7 +499,8 @@ def IREELinalgExt_FftOp : IREELinalgExt_Op<"fft", [
498499
}
499500

500501
def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",
501-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
502+
[AttrSizedOperandSegments,
503+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
502504
DeclareOpInterfaceMethods<TilingInterface,
503505
["generateScalarImplementation",
504506
"getIterationDomain",
@@ -554,6 +556,7 @@ def IREELinalgExt_ScanOp : IREELinalgExt_Op<"scan",
554556
}
555557

556558
def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
559+
AttrSizedOperandSegments,
557560
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
558561
DeclareOpInterfaceMethods<LinalgExtInterface>,
559562
DeclareOpInterfaceMethods<TilingInterface,
@@ -631,6 +634,7 @@ def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
631634
}
632635

633636
def IREELinalgExt_ArgCompareOp : IREELinalgExt_Op<"arg_compare", [
637+
AttrSizedOperandSegments,
634638
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
635639
DeclareOpInterfaceMethods<LinalgFusionInterface,
636640
["getIndexingMapsForResults", "getIndexingMapsForOperands",
@@ -760,10 +764,8 @@ def IREELinalgExt_ArgCompareOp : IREELinalgExt_Op<"arg_compare", [
760764
// Attention
761765
//===----------------------------------------------------------------------===//
762766

763-
def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
764-
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
765-
DestinationStyleOpInterface, LinalgExtInterface,
766-
DeclareOpInterfaceMethods<LinalgFusionInterface,
767+
def IREELinalgExt_AttentionOp : IREELinalgExt_Op<"attention",
768+
[DeclareOpInterfaceMethods<LinalgFusionInterface,
767769
["getIndexingMapsForResults", "getIndexingMapsForOperands",
768770
"getStaticLoopRanges"]>,
769771
DeclareOpInterfaceMethods<IndexingMapOpInterface, ["getMatchingIndexingMap"]>,
@@ -872,11 +874,8 @@ def IREELinalgExt_AttentionOp : IREELinalgExt_PureOp<"attention",
872874
// OnlineAttention
873875
//===----------------------------------------------------------------------===//
874876

875-
def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
876-
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
877-
DestinationStyleOpInterface,
878-
IndexingMapOpInterface,
879-
LinalgExtInterface,
877+
def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_Op<"online_attention",
878+
[IndexingMapOpInterface,
880879
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
881880
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
882881
DeclareOpInterfaceMethods<TilingInterface,
@@ -1008,7 +1007,8 @@ def IREELinalgExt_OnlineAttentionOp : IREELinalgExt_PureOp<"online_attention",
10081007
// ExpReduction
10091008
//===----------------------------------------------------------------------===//
10101009

1011-
def IREELinalgExt_ExpReductionOp : IREELinalgExt_Op<"exp_reduction"> {
1010+
def IREELinalgExt_ExpReductionOp : IREELinalgExt_Op<"exp_reduction",
1011+
[AttrSizedOperandSegments]> {
10121012
let summary = [{
10131013
A linalg.generic extension with support for exponential reduction.
10141014
}];
@@ -1163,7 +1163,8 @@ def IREELinalgExt_ExpReductionOp : IREELinalgExt_Op<"exp_reduction"> {
11631163
//===----------------------------------------------------------------------===//
11641164

11651165
def IREELinalgExt_Im2colOp : IREELinalgExt_Op<"im2col",
1166-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1166+
[AttrSizedOperandSegments,
1167+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
11671168
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
11681169
DeclareOpInterfaceMethods<TilingInterface,
11691170
["getIterationDomain",
@@ -1358,6 +1359,7 @@ def OpGroupDataTilingOps : OpDocGroup {
13581359
let opDocGroup = OpGroupDataTilingOps in {
13591360

13601361
def IREELinalgExt_PackOp : IREELinalgExt_Op<"pack", [
1362+
AttrSizedOperandSegments,
13611363
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
13621364
DeclareOpInterfaceMethods<LinalgExtInterface>,
13631365
DeclareOpInterfaceMethods<TilingInterface,
@@ -1521,6 +1523,7 @@ def IREELinalgExt_PackOp : IREELinalgExt_Op<"pack", [
15211523
}
15221524

15231525
def IREELinalgExt_UnPackOp : IREELinalgExt_Op<"unpack", [
1526+
AttrSizedOperandSegments,
15241527
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
15251528
DeclareOpInterfaceMethods<LinalgExtInterface>,
15261529
DeclareOpInterfaceMethods<TilingInterface,
@@ -1655,7 +1658,8 @@ def OpGroupWinogradOps : OpDocGroup {
16551658
let opDocGroup = OpGroupWinogradOps in {
16561659

16571660
def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_transform",
1658-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1661+
[AttrSizedOperandSegments,
1662+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
16591663
DeclareOpInterfaceMethods<TilingInterface,
16601664
["getIterationDomain",
16611665
"getLoopIteratorTypes",
@@ -1760,7 +1764,8 @@ def IREELinalgExt_WinogradInputTransformOp : IREELinalgExt_Op<"winograd.input_tr
17601764
}
17611765

17621766
def IREELinalgExt_WinogradFilterTransformOp : IREELinalgExt_Op<"winograd.filter_transform",
1763-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1767+
[AttrSizedOperandSegments,
1768+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
17641769
DeclareOpInterfaceMethods<TilingInterface,
17651770
["getIterationDomain",
17661771
"getLoopIteratorTypes",
@@ -1870,7 +1875,8 @@ def IREELinalgExt_WinogradFilterTransformOp : IREELinalgExt_Op<"winograd.filter_
18701875
}
18711876

18721877
def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_transform",
1873-
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
1878+
[AttrSizedOperandSegments,
1879+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
18741880
DeclareOpInterfaceMethods<TilingInterface,
18751881
["getIterationDomain",
18761882
"getLoopIteratorTypes",
@@ -1985,6 +1991,7 @@ def IREELinalgExt_WinogradOutputTransformOp : IREELinalgExt_Op<"winograd.output_
19851991
//===---------------------------------------------------------------------===//
19861992

19871993
def IREELinalgExt_CustomOp : IREELinalgExt_Op<"custom_op", [
1994+
AttrSizedOperandSegments,
19881995
DeclareOpInterfaceMethods<AggregatedOpInterface, [
19891996
"decomposeOperation"]>,
19901997
DeclareOpInterfaceMethods<LinalgFusionInterface>,

0 commit comments

Comments
 (0)