Skip to content

Commit 7d0d249

Browse files
authored
Fix incorrect fusion check (#1956)
In AffixTuningParameters we will calculate the GemmFeatures based on the ops in the func, at which point we will then write this as an attribute to the func. The problem with this is that MIGraphX, and other users of rocMLIR, could query isModuleFusible at any point, potentially before AffixTuningParameters has run. This is perfectly legal behavior, but there was a bug in fusionUtils that expected the GemmFeatures to have already been written to the func when isModuleFusible was called. The fix for this that I came up with was to call rock::getFeatures directly on ReduceOps and teach getFeatures how to handle these ops. I also added in a new CAPI test as we seemingly did not have one testing fusibility of reduce ops.
1 parent 5647a4b commit 7d0d249

19 files changed

+306
-100
lines changed

mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mlir_tablegen(RockAccelTuningParamAttrInterface.h.inc -gen-attr-interface-decls)
2020
mlir_tablegen(RockAccelTuningParamAttrInterface.cpp.inc -gen-attr-interface-defs)
2121
add_public_tablegen_target(MLIRRockAccelTuningParamAttrInterfaceIncGen)
2222

23+
add_mlir_interface(RockGemmFeaturesInterface)
2324
add_mlir_interface(RockGemmGemmWrapperInterface)
2425
add_mlir_interface(RockGemmWrapperInterface)
2526
add_mlir_interface(RockConvInterface)

mlir/include/mlir/Dialect/Rock/IR/GetRockInfo.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ inline rock::GemmFeatures intersectGemmFeatures(rock::GemmFeatures a,
5252
// the architecture being used, and the type of the op.
5353
rock::GemmFeatures getFeatures(Operation *op);
5454

55-
// This function returns a boolean value if the underlying op has support for
56-
// an optional 'GemmFeatures' attribute
57-
bool opHasOptionalFeature(Operation *op);
58-
5955
} // End namespace rock
6056
} // End namespace mlir
6157
#endif // MLIR_DIALECT_ROCK_IR_GETROCKINFO_H

mlir/include/mlir/Dialect/Rock/IR/Rock.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ constexpr int64_t maxHardwareWorkgroupSize = 1024;
8282

8383
#include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.h"
8484
#include "mlir/Dialect/Rock/IR/RockConvInterface.h"
85+
#include "mlir/Dialect/Rock/IR/RockGemmFeaturesInterface.h"
8586
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
8687
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
8788
#include "mlir/Dialect/Rock/IR/RockWriterOpInterface.h"

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define ROCK_ATTRS
1111

1212
include "mlir/Dialect/Rock/IR/RockBase.td"
13+
include "mlir/Dialect/Rock/IR/RockGemmFeaturesInterface.td"
1314
include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td"
1415
include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td"
1516
include "mlir/Dialect/Rock/IR/RockTuningParamAttrInterface.td"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===---- RockGemmFeaturesInterface.h - ops that wrap rock.gemm -*- C++ -*-===//
2+
//
3+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// Copyright (c) 2025 Advanced Micro Devices INc.
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file defines RockGemmFeaturesInterface, which abstracts Rock ops for
11+
// which we would expect to extract GemmFeatures for.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef MLIR_DIALECT_ROCK_IR_ROCKGEMMFEATURESINTERFACE_H
16+
#define MLIR_DIALECT_ROCK_IR_ROCKGEMMFEATURESINTERFACE_H
17+
18+
#include "mlir/Dialect/Rock/IR/GemmSize.h"
19+
#include "mlir/IR/OpDefinition.h"
20+
21+
#include "mlir/Dialect/Rock/IR/RockTypes.h"
22+
23+
#include "mlir/Dialect/Rock/IR/RockGemmFeaturesInterface.h.inc"
24+
25+
#endif // MLIR_DIALECT_ROCK_IR_ROCKGEMMFEATURESINTERFACE_H
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===------------------ RockGemmFeaturesInterface.td ----------------------===//
2+
//
3+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
// Copyright (c) 2025 Advanced Micro Devices Inc.
8+
//===----------------------------------------------------------------------===//
9+
//
10+
// This file defines RockGemmFeaturesInterface, which abstracts Rock ops for
11+
// which we would expect to extract GemmFeatures for.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#ifndef ROCK_GEMM_FEATURES_INTERFACE
16+
#define ROCK_GEMM_FEATURES_INTERFACE
17+
18+
include "mlir/IR/OpBase.td"
19+
20+
def RockGemmFeaturesInterface : OpInterface<"RockGemmFeaturesInterface"> {
21+
let description = [{
22+
Interface to abstract away operations for which we need to extract
23+
information relating to GemmFeatures from.
24+
}];
25+
let cppNamespace = "::mlir::rock";
26+
27+
let methods = [InterfaceMethod<
28+
/*desc=*/[{
29+
Return the type from this op that is needed to calculate GemmFeatures
30+
}],
31+
/*retType=*/"SmallVector<::mlir::Type>",
32+
/*methodName=*/"getTypesForFeature",
33+
/*args=*/(ins),
34+
/*methodBody=*/"",
35+
/*defaultImplementation=*/"">];
36+
}
37+
38+
#endif // ROCK_GEMM_FEATURES_INTERFACE

mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,6 @@ def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> {
3030
let cppNamespace = "::mlir::rock";
3131

3232
let methods = [
33-
InterfaceMethod<
34-
/*desc=*/[{
35-
Return the type from this op that is needed to calculate GemmFeatures
36-
}],
37-
/*retType=*/"SmallVector<::mlir::Type>",
38-
/*methodName=*/"getTypesForFeature",
39-
/*args=*/(ins),
40-
/*methodBody=*/"return {$_op.getAType(), $_op.getCType()};",
41-
/*defaultImplementation=*/""
42-
>,
4333
InterfaceMethod<
4434
/*desc=*/[{
4535
Return the KernelType of this op

mlir/include/mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@ def RockGemmWrapperInterface : OpInterface<"RockGemmWrapperInterface"> {
2828
let cppNamespace = "::mlir::rock";
2929

3030
let methods = [
31-
InterfaceMethod<
32-
/*desc=*/[{
33-
Return the type from this op that is needed to calculate GemmFeatures
34-
}],
35-
/*retType=*/"SmallVector<::mlir::Type>",
36-
/*methodName=*/"getTypesForFeature",
37-
/*args=*/(ins),
38-
/*methodBody=*/"return {$_op.getAType()};",
39-
/*defaultImplementation=*/""
40-
>,
4131
InterfaceMethod<
4232
/*desc=*/[{
4333
Return the KernelType of this op

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
include "mlir/Dialect/Rock/IR/RockAttrDefs.td"
1717
include "mlir/Dialect/Rock/IR/RockConvInterface.td"
18+
include "mlir/Dialect/Rock/IR/RockGemmFeaturesInterface.td"
1819
include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td"
1920
include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td"
2021
include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.td"
@@ -61,11 +62,13 @@ class TensorOrMemRefOf<list<Type> allowedTypes> :
6162

6263
class IndexArrayLength<int n> : ConfinedAttr<IndexArrayAttr, [ArrayMinCount<n>]>;
6364

64-
class Rock_ConvOpBase<string mnemonic, list<Type> inputTypes=[F32, F16, BF16], list<Type> outputTypes=[F32, F16, BF16]> :
65-
Rock_Op<mnemonic, [DeclareOpInterfaceMethods<RockGemmWrapperInterface>,
66-
DeclareOpInterfaceMethods<RockConvInterface>,
67-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
68-
RockFusionRoot]>{
65+
class Rock_ConvOpBase<string mnemonic, list<Type> inputTypes = [F32, F16, BF16],
66+
list<Type> outputTypes = [F32, F16, BF16]>
67+
: Rock_Op<mnemonic, [DeclareOpInterfaceMethods<RockGemmWrapperInterface>,
68+
DeclareOpInterfaceMethods<RockConvInterface>,
69+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
70+
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
71+
RockFusionRoot]> {
6972
dag commonConvArgs = (ins TensorOrMemRefOf<inputTypes>:$filter,
7073
TensorOrMemRefOf<inputTypes>:$input,
7174
TensorOrMemRefOf<outputTypes>:$output,
@@ -132,6 +135,7 @@ def Rock_ConvBwdWeightOp : Rock_ConvOpBase<"conv_bwd_weight">
132135

133136
def Rock_GemmOp
134137
: Rock_Op<"gemm", [DeclareOpInterfaceMethods<RockGemmWrapperInterface>,
138+
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
135139
RockFusionRoot]>,
136140
Arguments<(
137141
ins Arg<TensorOrMemRefOf<GemmInputTypes>, "matrix A", [MemRead]>:$a,
@@ -187,12 +191,17 @@ def Rock_ReduceOp
187191
}];
188192
let extraClassDeclaration = [{
189193
::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperand(1); }
194+
195+
SmallVector<::mlir::Type> getTypesForFeature() {
196+
return {getIn().getType()};
197+
}
190198
}];
191199
}
192200
def Rock_AttentionOp
193201
: Rock_Op<
194202
"attention", [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
195203
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
204+
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
196205
RockFusionRoot, AttrSizedOperandSegments,
197206
AttrSizedResultSegments]>,
198207
Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
@@ -260,6 +269,7 @@ def Rock_GemmElementwiseGemmOp
260269
: Rock_Op<"gemm_elementwise_gemm",
261270
[DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
262271
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
272+
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
263273
RockFusionRoot]>,
264274
AllElementTypesMatch<["a", "b"]>,
265275
Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16]>:$a,
@@ -310,6 +320,7 @@ def Rock_ConvElementwiseGemmOp
310320
: Rock_Op<"conv_elementwise_gemm",
311321
[DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
312322
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
323+
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
313324
RockFusionRoot]>,
314325
AllElementTypesMatch<["filter", "input"]>,
315326
Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16]>:$filter,
@@ -486,7 +497,9 @@ def Rock_TensorUntransformCastOp :
486497
}
487498

488499
def Rock_GridwiseGemmOp
489-
: Rock_Op<"gridwise_gemm", [RockFusionRoot]>,
500+
: Rock_Op<"gridwise_gemm", [DeclareOpInterfaceMethods<
501+
RockGemmFeaturesInterface>,
502+
RockFusionRoot]>,
490503
Arguments<(ins Arg<MemRefRankOf<GemmInputTypes, [3]>,
491504
"matrix A view", [MemRead]>:$a,
492505
Arg<MemRefRankOf<GemmInputTypes, [3]>, "matrix B view", [MemRead]>:$b,
@@ -503,19 +516,13 @@ def Rock_GridwiseGemmOp
503516
$c `=` $a `*` $b `storeMethod` `(` $storeMethod `)` (`features` `=` $features^)? attr-dict `:` type($c) `=` type($a) `*` type($b)
504517
}];
505518
let hasVerifier = 1;
506-
507-
// Return the type from this op that is needed to calculate GemmFeatures
508-
let extraClassDeclaration = [{
509-
SmallVector<::mlir::Type> getTypesForFeature() {
510-
SmallVector<::mlir::Type> types = {getA().getType()};
511-
return types;
512-
}
513-
}];
514519
}
515520

516521
// gridwise_gemm_accel
517522
def Rock_GridwiseGemmAccelOp
518-
: Rock_Op<"gridwise_gemm_accel", [RockFusionRoot]>,
523+
: Rock_Op<"gridwise_gemm_accel", [DeclareOpInterfaceMethods<
524+
RockGemmFeaturesInterface>,
525+
RockFusionRoot]>,
519526
Arguments<(ins Arg<MemRefRankOf<GemmInputTypes, [3]>,
520527
"matrix A view", [MemRead]>:$a,
521528
Arg<MemRefRankOf<GemmInputTypes, [3]>, "matrix B view", [MemRead]>:$b,
@@ -532,21 +539,15 @@ def Rock_GridwiseGemmAccelOp
532539
`(` operands `)` `storeMethod` `(` $storeMethod `)` (`features` `=` $features^)? attr-dict `:` type(operands)
533540
}];
534541
let hasVerifier = 1;
535-
536-
// Return the type from this op that is needed to calculate GemmFeatures
537-
let extraClassDeclaration = [{
538-
SmallVector<::mlir::Type> getTypesForFeature() {
539-
SmallVector<::mlir::Type> types = {getA().getType()};
540-
return types;
541-
}
542-
}];
543542
}
544543

545544
// gridwise_attention_accel
546545
def Rock_GridwiseAttentionAccelOp
547546
: Rock_Op<"gridwise_attention_accel",
548547
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
549-
RockFusionRoot, AttrSizedOperandSegments]>,
548+
DeclareOpInterfaceMethods<RockGemmFeaturesInterface>,
549+
RockFusionRoot, AttrSizedOperandSegments,
550+
]>,
550551
Arguments<(ins MemRefRankOf<[F32, F16, BF16, I8], [3]>:$queries,
551552
MemRefRankOf<[F32, F16, BF16, I8], [3]>:$keys,
552553
MemRefRankOf<[F32, F16, BF16], [3]>:$values,
@@ -572,15 +573,6 @@ def Rock_GridwiseAttentionAccelOp
572573
`(` operands `)` (`features` `=` $features^)? `preSoftmaxOps` `=` $preSoftmaxBody attr-dict `:` type(operands)
573574
}];
574575
let hasVerifier = 1;
575-
576-
// Return the type from this op that is needed to calculate GemmFeatures
577-
let extraClassDeclaration = [{
578-
SmallVector<::mlir::Type> getTypesForFeature() {
579-
SmallVector<::mlir::Type> types = {getKeys().getType(),
580-
getValues().getType()};
581-
return types;
582-
}
583-
}];
584576
}
585577

586578
// Memory allocation on GPU memory hierachy.
@@ -1370,7 +1362,8 @@ defvar AccelResTypes = [VectorOfLengthAndType<[4, 8, 16, 32], [F32, I32, F16, BF
13701362

13711363
// blockwise_gemm_accel
13721364
def Rock_BlockwiseGemmAccelOp
1373-
: Rock_Op<"blockwise_gemm_accel">,
1365+
: Rock_Op<"blockwise_gemm_accel", [DeclareOpInterfaceMethods<
1366+
RockGemmFeaturesInterface>]>,
13741367
Arguments<(ins MemRefOf<LdsBufferTypes>:$matrixA,
13751368
MemRefOf<LdsBufferTypes>:$matrixB, I32Attr:$inMPerThread,
13761369
I32Attr:$inNPerThread, UnitAttr:$rotateMWithK, UnitAttr:$rotateNWithK,
@@ -1393,14 +1386,6 @@ def Rock_BlockwiseGemmAccelOp
13931386
`:` type($matrixC) `+` `` `=` type($bufferA) `from` type($matrixA) `*`
13941387
type($bufferB) `from` type($matrixB)
13951388
}];
1396-
1397-
// Return the type from this op that is needed to calculate GemmFeatures
1398-
let extraClassDeclaration = [{
1399-
SmallVector<::mlir::Type> getTypesForFeature() {
1400-
SmallVector<::mlir::Type> types = {getMatrixA().getType()};
1401-
return types;
1402-
}
1403-
}];
14041389
}
14051390

14061391
// threadwise_gemm
@@ -1426,7 +1411,8 @@ def Rock_ThreadwiseGemmOp:
14261411
}
14271412
// threadwise_accel_gemm
14281413
def Rock_ThreadwiseAccelGemmOp
1429-
: Rock_Op<"threadwise_accel_gemm">,
1414+
: Rock_Op<"threadwise_accel_gemm", [DeclareOpInterfaceMethods<
1415+
RockGemmFeaturesInterface>]>,
14301416
Arguments<(ins Arg<MemRefOf<NativeMemoryOpTypes>,
14311417
"source register view A", [MemRead]>:$matrixA,
14321418
Arg<MemRefOf<NativeMemoryOpTypes>,
@@ -1448,14 +1434,6 @@ def Rock_ThreadwiseAccelGemmOp
14481434
`:` type($matrixC) `+` `` `=` type($matrixA) `*` type($matrixB)
14491435
}];
14501436
let hasVerifier = 1;
1451-
1452-
// Return the type from this op that is needed to calculate GemmFeatures
1453-
let extraClassDeclaration = [{
1454-
SmallVector<::mlir::Type> getTypesForFeature() {
1455-
SmallVector<::mlir::Type> types = {getMatrixA().getType()};
1456-
return types;
1457-
}
1458-
}];
14591437
}
14601438

14611439
// blockwise_broadcasting_reduction

mlir/lib/Dialect/Rock/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_rocmlir_dialect_library(MLIRRockOps
44
GetRockInfo.cpp
55
TransformMapBuilder.cpp
66
RockDialect.cpp
7+
RockGemmFeaturesInterface.cpp
78
RockGemmWrapperInterface.cpp
89
RockGemmGemmWrapperInterface.cpp
910
RockConvInterface.cpp
@@ -19,6 +20,7 @@ add_rocmlir_dialect_library(MLIRRockOps
1920

2021
DEPENDS
2122
MLIRRockAttrDefsIncGen
23+
MLIRRockGemmFeaturesInterfaceIncGen
2224
MLIRRockGemmWrapperInterfaceIncGen
2325
MLIRRockGemmGemmWrapperInterfaceIncGen
2426
MLIRRockTuningParamAttrInterfaceIncGen

0 commit comments

Comments
 (0)