-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][math] powf(a, b) drop support when a < 0
#126338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
0e790b6
b248c22
23d2f09
0e7dc19
c52ba9f
e1e06ec
9cf3d3b
ec9f128
95c3d55
79c4ef4
393aaa8
f5205a6
d5ee522
48b9405
640ec45
e976ba6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,8 +17,13 @@ | |||||||||||||||||||||||||||||
| #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||||||||||||||||||||||||||||||
| #include "mlir/IR/Builders.h" | ||||||||||||||||||||||||||||||
| #include "mlir/IR/ImplicitLocOpBuilder.h" | ||||||||||||||||||||||||||||||
| #include "mlir/IR/Matchers.h" | ||||||||||||||||||||||||||||||
| #include "mlir/IR/PatternMatch.h" | ||||||||||||||||||||||||||||||
| #include "mlir/IR/TypeUtilities.h" | ||||||||||||||||||||||||||||||
| #include "mlir/Transforms/DialectConversion.h" | ||||||||||||||||||||||||||||||
| #include "llvm/ADT/APFloat.h" | ||||||||||||||||||||||||||||||
| #include "llvm/Support/LogicalResult.h" | ||||||||||||||||||||||||||||||
| #include <cmath> | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| using namespace mlir; | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -311,40 +316,113 @@ static LogicalResult convertFPowIOp(math::FPowIOp op, | |||||||||||||||||||||||||||||
| return success(); | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a)) | ||||||||||||||||||||||||||||||
| // Convert Powf(float a, float b) for some special cases | ||||||||||||||||||||||||||||||
| // where b == 1.0, b == 0.0, b == 0.5, b == -0.5, b == -1.0, and b % 2 == 0 | ||||||||||||||||||||||||||||||
| static LogicalResult convertSpecialPowfOp(math::PowFOp op, | ||||||||||||||||||||||||||||||
| PatternRewriter &rewriter) { | ||||||||||||||||||||||||||||||
| ImplicitLocOpBuilder b(op->getLoc(), rewriter); | ||||||||||||||||||||||||||||||
| Value operandA = op.getOperand(0); | ||||||||||||||||||||||||||||||
| Value operandB = op.getOperand(1); | ||||||||||||||||||||||||||||||
| auto baseType = operandB.getType(); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType)) | ||||||||||||||||||||||||||||||
| .getFloatSemantics(); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto valueB = APFloat(sem); | ||||||||||||||||||||||||||||||
| if (!matchPattern(operandB, m_ConstantFloat(&valueB))) { | ||||||||||||||||||||||||||||||
| // Not a constant, return failure | ||||||||||||||||||||||||||||||
| return failure(); | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
| float floatValueB = valueB.convertToFloat(); | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
| /// We don't rely on operator== working on double values, as | |
| /// it returns true for things that are clearly not equal, like -0.0 and 0.0. | |
| /// As such, this method can be used to do an exact bit-for-bit comparison of | |
| /// two floating point values. | |
| /// | |
| /// We leave the version with the double argument here because it's just so | |
| /// convenient to write "2.0" and the like. Without this function we'd | |
| /// have to duplicate its logic everywhere it's called. | |
| bool isExactlyValue(double V) const { | |
| bool ignored; | |
| APFloat Tmp(V); | |
| Tmp.convert(getSemantics(), APFloat::rmNearestTiesToEven, &ignored); | |
| return bitwiseIsEqual(Tmp); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed!
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't try to handle arbitrary integer x. Just handle the special value 2.0, and maybe also 3.0 and 4.0 if you want, but that's it. If someone really needs a larger integral exponent to be match, we can always expand this pattern later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, let's handle few cases for now and document it in the function comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kept only |b|=2.0 case and removed other integer cases.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is this actually checked? This seems to be expanding under this assumption, but always does it. Is this a new assumption on the op that should be documented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry!. I forgot to describe PR in more detail.
- I believe it should be documented, in general,
powf(a, b)wherea < 0generally yields NaN and we (as far as I know) aren't able to check it runtime.
- This transform should be applied to some small number of 'b' (e.g., when 'abs(b) < 16')
while (absIntValueB > 0) {
if (absIntValueB & 1) {
result = b.create<arith::MulFOp>(result, current);
}
current = b.create<arith::MulFOp>(current, current);
absIntValueB >>= 1;
}
rewriter.replaceOp(op, result);
return success();The heuristic number for b is not determined yet. This last case can be dropped if it's not necessary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One problem is that there are some special use-cases where var a < 0 but const b == some multiple of 2 cc @hanhanW
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just drop the comment // Restricting a >= 0 here.
Mathematically, the power operation a^b, is well-defined in two separate (though overlapping) cases:
- When
a > 0. In that case,a^bis defined asexp(b * ln(a)). - When
bis an integer. In that case,a^bis defined asa * ... * a, (btimes), or the reciprocal of that ifbis negative.
These two definitions agree in the intersection of these two cases.
Because "power" has inherently that two-mode definition, the MLIR op powf should have been specified from the start to implement one of these two modes only. Obviously it should have been a > 0.
I believe that it is still time to clarify that. We have observed recently that some rewrite patterns for powf ops have been broken outside of the case a > 0, suggesting that no one was relying on that.
But that discussion doesn't need to be conflated into this PR, because this PR implements rewrites that are either agnostic as to which case we are in (e.g. the case of pow(a, 2.0)) or that are explicitly not applying to the other case anyway (e.g. the case of pow(a, 0.5)).
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question for MLIR experts: Here, we want the convertSpecialPowfOp to have precedence over the convertPowfOp pattern. Is that ensured by it being added first here? If not, do we need to merge these two patterns to ensure ordering?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
patterns.add(convertSpecialPowfOp, /*benefit=*/ 2);
This would explicitly give the order we want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not very sure since I didn't check the code, but adding patterns in this order make convertSpecialPowfOp run first.
patterns.add(convertSpecialPowfOp);
patterns.add(convertPowfOp);There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@hanhanW ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not very sure since I didn't check the code, but adding patterns in this order make convertSpecialPowfOp run first.
This is correct, but I think it is not documented. IMO, we prefer using benefit to prioritize the patterns.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to give explicit benefit=2!
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At least we need all the special cases are tested in this file.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. my apology for that tihs PR got verbose, I made appropriate tests and runs well!! |
ita9naiwa marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, since
math.powfrequires float arguments (I just checked MathOps.td), the cast really shouldn't ever fail, so I think you can simply usecastinstead ofdyn_cast. You weren't checking for a null return value fromdyn_castanyway.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!