Skip to content

Commit b9314a8

Browse files
authored
[mlir][spirv] Update math.powf lowering (#111388)
The PR updates math.powf lowering to produce NaN result for a negative base with a fractional exponent which matches the actual behaviour of the C/C++ implementation.
1 parent fed8695 commit b9314a8

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,19 +377,42 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
377377
// Get int type of the same shape as the float type.
378378
Type scalarIntType = rewriter.getIntegerType(32);
379379
Type intType = scalarIntType;
380-
if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs().getType())) {
380+
auto operandType = adaptor.getRhs().getType();
381+
if (auto vectorType = dyn_cast<VectorType>(operandType)) {
381382
auto shape = vectorType.getShape();
382383
intType = VectorType::get(shape, scalarIntType);
383384
}
384385

385386
// Per GL Pow extended instruction spec:
386387
// "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
387388
Location loc = powfOp.getLoc();
388-
Value zero =
389-
spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
389+
Value zero = spirv::ConstantOp::getZero(operandType, loc, rewriter);
390390
Value lessThan =
391391
rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
392-
Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
392+
393+
// Per C/C++ spec:
394+
// > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
395+
// > finite and negative and exponent is finite and non-integer.
396+
// Calculate the reminder from the exponent and check whether it is zero.
397+
Value floatOne = spirv::ConstantOp::getOne(operandType, loc, rewriter);
398+
Value expRem =
399+
rewriter.create<spirv::FRemOp>(loc, adaptor.getRhs(), floatOne);
400+
Value expRemNonZero =
401+
rewriter.create<spirv::FOrdNotEqualOp>(loc, expRem, zero);
402+
Value cmpNegativeWithFractionalExp =
403+
rewriter.create<spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
404+
// Create NaN result and replace base value if conditions are met.
405+
const auto &floatSemantics = scalarFloatType.getFloatSemantics();
406+
const auto nan = APFloat::getNaN(floatSemantics);
407+
Attribute nanAttr = rewriter.getFloatAttr(scalarFloatType, nan);
408+
if (auto vectorType = dyn_cast<VectorType>(operandType))
409+
nanAttr = DenseElementsAttr::get(vectorType, nan);
410+
411+
Value NanValue =
412+
rewriter.create<spirv::ConstantOp>(loc, operandType, nanAttr);
413+
Value lhs = rewriter.create<spirv::SelectOp>(
414+
loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs());
415+
Value abs = rewriter.create<spirv::GLFAbsOp>(loc, lhs);
393416

394417
// TODO: The following just forcefully casts y into an integer value in
395418
// order to properly propagate the sign, assuming integer y cases. It

mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,13 @@ func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> {
156156
func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
157157
// CHECK: %[[F0:.+]] = spirv.Constant 0.000000e+00 : f32
158158
// CHECK: %[[LT:.+]] = spirv.FOrdLessThan %[[LHS]], %[[F0]] : f32
159-
// CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[LHS]] : f32
159+
// CHECK: %[[F1:.+]] = spirv.Constant 1.000000e+00 : f32
160+
// CHECK: %[[REM:.+]] = spirv.FRem %[[RHS]], %[[F1]] : f32
161+
// CHECK: %[[IS_FRACTION:.+]] = spirv.FOrdNotEqual %[[REM]], %[[F0]] : f32
162+
// CHECK: %[[AND:.+]] = spirv.LogicalAnd %[[IS_FRACTION]], %[[LT]] : i1
163+
// CHECK: %[[NAN:.+]] = spirv.Constant 0x7FC00000 : f32
164+
// CHECK: %[[NEW_LHS:.+]] = spirv.Select %[[AND]], %[[NAN]], %[[LHS]] : i1, f32
165+
// CHECK: %[[ABS:.+]] = spirv.GL.FAbs %[[NEW_LHS]] : f32
160166
// CHECK: %[[IRHS:.+]] = spirv.ConvertFToS
161167
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
162168
// CHECK: %[[REM:.+]] = spirv.BitwiseAnd %[[IRHS]]
@@ -173,6 +179,10 @@ func.func @powf_scalar(%lhs: f32, %rhs: f32) -> f32 {
173179
// CHECK-LABEL: @powf_vector
174180
func.func @powf_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) -> vector<4xf32> {
175181
// CHECK: spirv.FOrdLessThan
182+
// CHECK: spirv.FRem
183+
// CHECK: spirv.FOrdNotEqual
184+
// CHECK: spirv.LogicalAnd
185+
// CHECK: spirv.Select
176186
// CHECK: spirv.GL.FAbs
177187
// CHECK: spirv.BitwiseAnd %{{.*}} : vector<4xi32>
178188
// CHECK: spirv.IEqual %{{.*}} : vector<4xi32>

0 commit comments

Comments
 (0)