@@ -377,19 +377,42 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
377377 // Get int type of the same shape as the float type.
378378 Type scalarIntType = rewriter.getIntegerType (32 );
379379 Type intType = scalarIntType;
380- if (auto vectorType = dyn_cast<VectorType>(adaptor.getRhs ().getType ())) {
380+ auto operandType = adaptor.getRhs ().getType ();
381+ if (auto vectorType = dyn_cast<VectorType>(operandType)) {
381382 auto shape = vectorType.getShape ();
382383 intType = VectorType::get (shape, scalarIntType);
383384 }
384385
385386 // Per GL Pow extended instruction spec:
386387 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
387388 Location loc = powfOp.getLoc ();
388- Value zero =
389- spirv::ConstantOp::getZero (adaptor.getLhs ().getType (), loc, rewriter);
389+ Value zero = spirv::ConstantOp::getZero (operandType, loc, rewriter);
390390 Value lessThan =
391391 rewriter.create <spirv::FOrdLessThanOp>(loc, adaptor.getLhs (), zero);
392- Value abs = rewriter.create <spirv::GLFAbsOp>(loc, adaptor.getLhs ());
392+
393+ // Per C/C++ spec:
394+ // > pow(base, exponent) returns NaN (and raises FE_INVALID) if base is
395+ // > finite and negative and exponent is finite and non-integer.
396+ // Calculate the reminder from the exponent and check whether it is zero.
397+ Value floatOne = spirv::ConstantOp::getOne (operandType, loc, rewriter);
398+ Value expRem =
399+ rewriter.create <spirv::FRemOp>(loc, adaptor.getRhs (), floatOne);
400+ Value expRemNonZero =
401+ rewriter.create <spirv::FOrdNotEqualOp>(loc, expRem, zero);
402+ Value cmpNegativeWithFractionalExp =
403+ rewriter.create <spirv::LogicalAndOp>(loc, expRemNonZero, lessThan);
404+ // Create NaN result and replace base value if conditions are met.
405+ const auto &floatSemantics = scalarFloatType.getFloatSemantics ();
406+ const auto nan = APFloat::getNaN (floatSemantics);
407+ Attribute nanAttr = rewriter.getFloatAttr (scalarFloatType, nan);
408+ if (auto vectorType = dyn_cast<VectorType>(operandType))
409+ nanAttr = DenseElementsAttr::get (vectorType, nan);
410+
411+ Value NanValue =
412+ rewriter.create <spirv::ConstantOp>(loc, operandType, nanAttr);
413+ Value lhs = rewriter.create <spirv::SelectOp>(
414+ loc, cmpNegativeWithFractionalExp, NanValue, adaptor.getLhs ());
415+ Value abs = rewriter.create <spirv::GLFAbsOp>(loc, lhs);
393416
394417 // TODO: The following just forcefully casts y into an integer value in
395418 // order to properly propagate the sign, assuming integer y cases. It
0 commit comments