Skip to content

Commit 4b25005

Browse files
authored
feat: dot_general_licm (#1550)
1 parent b3d6fde commit 4b25005

File tree

7 files changed

+160
-5
lines changed

7 files changed

+160
-5
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25018,6 +25018,13 @@ void mlir::transform::addSliceLICM(RewritePatternSet &patterns,
2501825018
patterns.insert<LICM<stablehlo::SliceOp>>(single_user, &context, benefit);
2501925019
}
2502025020

25021+
void mlir::transform::addDotGeneralLICM(RewritePatternSet &patterns,
25022+
bool single_user, MLIRContext &context,
25023+
PatternBenefit benefit) {
25024+
patterns.insert<LICM<stablehlo::DotGeneralOp>>(single_user, &context,
25025+
benefit);
25026+
}
25027+
2502125028
void mlir::transform::addDUSLICM(RewritePatternSet &patterns, bool single_user,
2502225029
MLIRContext &context, PatternBenefit benefit) {
2502325030
patterns.insert<LICM<stablehlo::DynamicUpdateSliceOp>>(single_user, &context,

src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ void addWhileLICM(RewritePatternSet &patterns, bool hoist_all,
5959
MLIRContext &context, PatternBenefit benefit);
6060
void addSliceLICM(RewritePatternSet &patterns, bool single_user,
6161
MLIRContext &context, PatternBenefit benefit);
62+
void addDotGeneralLICM(RewritePatternSet &patterns, bool single_user,
63+
MLIRContext &context, PatternBenefit benefit);
6264
void addDUSLICM(RewritePatternSet &patterns, bool single_user,
6365
MLIRContext &context, PatternBenefit benefit);
6466
void addPadLICM(RewritePatternSet &patterns, bool single_user,

src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,11 @@ LogicalResult parseTransform(OpBuilder &builder, Location loc,
149149
if (parameter != -1) {
150150
if (opName == "no_nan_add_sub_simplify" || opName == "while_simplify" ||
151151
opName == "sum_to_conv" || opName == "while_licm" ||
152-
opName == "slice_licm" || opName == "dus_licm" ||
153-
opName == "pad_licm" || opName == "elementwise_licm" ||
154-
opName == "concatenate_licm" || opName == "broadcastindim_licm" ||
155-
opName == "reshape_licm" || opName == "transpose_licm" ||
156-
opName == "transpose_elementwise" ||
152+
opName == "slice_licm" || opName == "dot_general_licm" ||
153+
opName == "dus_licm" || opName == "pad_licm" ||
154+
opName == "elementwise_licm" || opName == "concatenate_licm" ||
155+
opName == "broadcastindim_licm" || opName == "reshape_licm" ||
156+
opName == "transpose_licm" || opName == "transpose_elementwise" ||
157157
opName == "reshape_elementwise" || opName == "reshape_slice" ||
158158
opName == "reshape_dynamic_slice" ||
159159
opName == "extend_unary_elementwise" ||

src/enzyme_ad/jax/TransformOps/TransformOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ void ApplySliceLICMPatterns::populatePatterns(RewritePatternSet &patterns) {
8080
addSliceLICM(patterns, getParameter(), *getContext(),
8181
PatternBenefit(getBenefit().value_or(1)));
8282
}
83+
void ApplyDotGeneralLICMPatterns::populatePatterns(
84+
RewritePatternSet &patterns) {
85+
addDotGeneralLICM(patterns, getParameter(), *getContext(),
86+
PatternBenefit(getBenefit().value_or(1)));
87+
}
8388
void ApplyDUSLICMPatterns::populatePatterns(RewritePatternSet &patterns) {
8489
addDUSLICM(patterns, getParameter(), *getContext(),
8590
PatternBenefit(getBenefit().value_or(1)));

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,21 @@ def ApplySliceLICMPatterns : EnzymeHLOParameterizedPatternOp<
11211121
}];
11221122
}
11231123

1124+
def ApplyDotGeneralLICMPatterns : EnzymeHLOParameterizedPatternOp<
1125+
"dot_general_licm"> {
1126+
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
1127+
let assemblyFormat = "attr-dict";
1128+
// TODO: this should be made better searchable.
1129+
let extraClassDeclaration = [{
1130+
::llvm::SmallVector<::mlir::DictionaryAttr>
1131+
static getPossibleAttrCombinations(::mlir::Builder &builder) {
1132+
return {builder.getDictionaryAttr(
1133+
builder.getNamedAttr("parameter",
1134+
builder.getBoolAttr(true)))};
1135+
}
1136+
}];
1137+
}
1138+
11241139
def ApplyDUSLICMPatterns : EnzymeHLOParameterizedPatternOp<
11251140
"dus_licm"> {
11261141
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def optimization_passes(
283283
"slice_dus_to_concat",
284284
"while_induction_reduction",
285285
"slice_licm(0)",
286+
"dot_general_licm(0)",
286287
"pad_licm(0)",
287288
"elementwise_licm(0)",
288289
"concatenate_licm(0)",

0 commit comments

Comments
 (0)