Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 135 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/numeric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,141 @@ def _(expr: TypedExpr) -> sge.Expression:
return expr.expr


@register_binary_op(ops.pow_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
if left.dtype == dtypes.INT_DTYPE and right.dtype == dtypes.INT_DTYPE:
return _int_pow_op(left_expr, right_expr)
else:
return _float_pow_op(left_expr, right_expr)


def _int_pow_op(
left_expr: sge.Expression, right_expr: sge.Expression
) -> sge.Expression:
import math

overflow_value = math.log(2**63 - 1)
overflow_cond = sge.and_(
sge.NEQ(this=left_expr, expression=sge.convert(0)),
sge.GT(
this=sge.Mul(
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
),
expression=sge.convert(overflow_value),
),
)

return sge.Case(
ifs=[
sge.If(
this=overflow_cond,
true=sge.Null(),
)
],
default=sge.Cast(
this=sge.Pow(
this=sge.Cast(
this=left_expr, to=sge.DataType(this=sge.DataType.Type.DECIMAL)
),
expression=right_expr,
),
to="INT64",
),
)


def _float_pow_op(
left_expr: sge.Expression, right_expr: sge.Expression
) -> sge.Expression:
# Most conditions here seek to prevent calling BQ POW with inputs that would generate errors.
# See: https://cloud.google.com/bigquery/docs/reference/standard-sql/mathematical_functions#pow
overflow_cond = sge.and_(
sge.NEQ(this=left_expr, expression=constants._ZERO),
sge.GT(
this=sge.Mul(
this=right_expr, expression=sge.Ln(this=sge.Abs(this=left_expr))
),
expression=constants._FLOAT64_EXP_BOUND,
),
)

# Float64 lose integer precision beyond 2**53, beyond this insufficient precision to get parity
exp_too_big = sge.GT(this=sge.Abs(this=right_expr), expression=sge.convert(2**53))
# Treat very large exponents as +=INF
norm_exp = sge.Case(
ifs=[
sge.If(
this=exp_too_big,
true=sge.Mul(this=constants._INF, expression=sge.Sign(this=right_expr)),
)
],
default=right_expr,
)

pow_result = sge.Pow(this=left_expr, expression=norm_exp)

# This cast is dangerous, need to only excuted where y_val has been bounds-checked
# Ibis needs try_cast binding to bq safe_cast
exponent_is_whole = sge.EQ(
this=sge.Cast(this=right_expr, to="INT64"), expression=right_expr
)
odd_exponent = sge.and_(
sge.LT(this=left_expr, expression=constants._ZERO),
sge.EQ(
this=sge.Mod(
this=sge.Cast(this=right_expr, to="INT64"), expression=sge.convert(2)
),
expression=sge.convert(1),
),
)
infinite_base = sge.EQ(this=sge.Abs(this=left_expr), expression=constants._INF)

return sge.Case(
ifs=[
# Might be able to do something more clever with x_val==0 case
sge.If(
this=sge.EQ(this=right_expr, expression=constants._ZERO),
true=sge.convert(1),
),
sge.If(
this=sge.EQ(this=left_expr, expression=sge.convert(1)),
true=sge.convert(1),
), # Need to ignore exponent, even if it is NA
sge.If(
this=sge.and_(
sge.EQ(this=left_expr, expression=constants._ZERO),
sge.LT(this=right_expr, expression=constants._ZERO),
),
true=constants._INF,
), # This case would error POW function in BQ
sge.If(this=infinite_base, true=pow_result),
sge.If(
this=exp_too_big, true=pow_result
), # Bigquery can actually handle the +-inf cases gracefully
sge.If(
this=sge.and_(
sge.LT(this=left_expr, expression=constants._ZERO),
sge.Not(this=exponent_is_whole),
),
true=constants._NAN,
),
sge.If(
this=overflow_cond,
true=sge.Mul(
this=constants._INF,
expression=sge.Case(
ifs=[sge.If(this=odd_exponent, true=sge.convert(-1))],
default=sge.convert(1),
),
),
), # finite overflows would cause bq to error
],
default=pow_result,
)


@register_unary_op(ops.sqrt_op)
def _(expr: TypedExpr) -> sge.Expression:
return sge.Case(
Expand Down
Loading