|
16 | 16 | import dataclasses
|
17 | 17 | import functools
|
18 | 18 | import itertools
|
19 |
| -from typing import Mapping, Optional, Sequence, Tuple |
| 19 | +from typing import cast, Mapping, Optional, Sequence, Tuple |
20 | 20 |
|
21 | 21 | import bigframes.core.expression as scalar_exprs
|
| 22 | +import bigframes.core.guid as guids |
22 | 23 | import bigframes.core.identifiers as ids
|
23 | 24 | import bigframes.core.join_def as join_defs
|
24 | 25 | import bigframes.core.nodes as nodes
|
25 | 26 | import bigframes.core.ordering as order
|
| 27 | +import bigframes.core.tree_properties as traversals |
26 | 28 | import bigframes.operations as ops
|
27 | 29 |
|
28 | 30 | Selection = Tuple[Tuple[scalar_exprs.Expression, ids.ColumnId], ...]
|
@@ -381,3 +383,172 @@ def common_selection_root(
|
381 | 383 | if r_node in l_nodes:
|
382 | 384 | return r_node
|
383 | 385 | return None
|
| 386 | + |
| 387 | + |
def replace_slice_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """Rewrite every slice node in the tree rooted at ``root``.

    Children are rewritten first (by passing this function to
    ``transform_children``); if ``root`` itself is a ``SliceNode``, the
    rewritten node is then handed to ``convert_slice_to_filter``.

    Args:
        root: Root of the (sub)tree to rewrite.

    Returns:
        The rewritten node.
    """
    # TODO: we want to pull up some slices into limit op if near root.
    rewritten = root.transform_children(replace_slice_ops)
    if not isinstance(root, nodes.SliceNode):
        return rewritten
    # NOTE(review): the cast presumes transform_children preserves the node
    # type of a SliceNode — confirm against nodes.BigFrameNode.
    return convert_slice_to_filter(cast(nodes.SliceNode, rewritten))
| 395 | + |
| 396 | + |
def get_simplified_slice(
    node: nodes.SliceNode,
) -> Tuple[int, Optional[int], int]:
    """Attempts to simplify the slice.

    Fills in the default ``start`` (0 for a positive step, -1 for a negative
    step). When the row count of the sliced input is known and the step is
    positive, negative ``start``/``stop`` values are rewritten as
    non-negative forward offsets.

    Args:
        node: Slice node whose ``start``/``stop``/``step`` are simplified.

    Returns:
        A ``(start, stop, step)`` tuple. ``start`` is never ``None``;
        ``stop`` may still be ``None`` (unbounded).
    """
    # Negative indices count back from the end of the rows being sliced, so
    # they are resolved against the child's row count, not the slice's own.
    # NOTE(review): presumably row_count is None when the size is not known
    # statically — confirm against tree_properties.row_count.
    row_count = traversals.row_count(node.child)
    start, stop, step = node.start, node.stop, node.step

    if start is None:
        start = 0 if step > 0 else -1
    if row_count is not None and step > 0:
        # Clamp at 0, as Python slicing does for a positive step. Leaving an
        # overshooting index negative (e.g. stop=-6 on 5 rows -> -1) would
        # make later stages read it as relative to the end a second time.
        if start < 0:
            start = max(row_count + start, 0)
        if stop is not None and stop < 0:
            stop = max(row_count + stop, 0)
    return start, stop, step
| 410 | + |
| 411 | + |
def convert_slice_to_filter(node: nodes.SliceNode):
    """Rewrite a slice node as offset-based operations on its child.

    Depending on the simplified ``(start, stop, step)`` this returns:

    * the child unchanged, when the slice keeps every row in order;
    * a ``ReversedNode`` over the child, for a plain ``[::-1]``;
    * otherwise a filter over one (forward) or two (forward and reverse)
      promoted offsets columns, with those helper columns dropped again.

    Args:
        node: The slice node to replace.

    Returns:
        The node that replaces ``node``.
    """
    start, stop, step = get_simplified_slice(node)

    # no-op (eg. df[::1]). Only an unbounded stop qualifies: stop == -1
    # would drop the last row.
    if ((start is None) or (start == 0)) and (stop is None) and (step == 1):
        return node.child
    # No filtering, just reverse (eg. df[::-1]). stop must be unbounded:
    # stop == 0 excludes the first row, so it is not a pure reversal.
    if ((start is None) or (start == -1)) and (stop is None) and (step == -1):
        return nodes.ReversedNode(node.child)
    # If start/stop are non-negative and the step is positive, a simple
    # predicate on forward offsets is enough. Negative steps must not take
    # this path: convert_simple_slice only handles positive step sizes.
    if (
        ((start is None) or (start >= 0))
        and ((stop is None) or (stop >= 0))
        and (step > 0)
    ):
        node_w_offset = add_offsets(node.child)
        predicate = convert_simple_slice(
            scalar_exprs.DerefOp(node_w_offset.col_id), start or 0, stop, step
        )
        filtered = nodes.FilterNode(node_w_offset, predicate)
        return drop_cols(filtered, (node_w_offset.col_id,))

    # fallback cases, generate both forward and backward offsets
    if step < 0:
        # Forward offsets first, then reverse once: the filter runs over the
        # reversed node.
        forward_offsets = add_offsets(node.child)
        reversed_offsets = add_offsets(nodes.ReversedNode(forward_offsets))
        dual_indexed = reversed_offsets
    else:
        # Reverse to number rows from the end, then reverse again: the filter
        # runs over the doubly reversed node.
        reversed_offsets = add_offsets(nodes.ReversedNode(node.child))
        forward_offsets = add_offsets(nodes.ReversedNode(reversed_offsets))
        dual_indexed = forward_offsets
    predicate = convert_complex_slice(
        scalar_exprs.DerefOp(forward_offsets.col_id),
        scalar_exprs.DerefOp(reversed_offsets.col_id),
        start,
        stop,
        step,
    )
    filtered = nodes.FilterNode(dual_indexed, predicate)
    return drop_cols(filtered, (forward_offsets.col_id, reversed_offsets.col_id))
| 452 | + |
| 453 | + |
def add_offsets(node: nodes.BigFrameNode) -> nodes.PromoteOffsetsNode:
    """Wrap ``node`` in a ``PromoteOffsetsNode`` with a freshly generated column id."""
    # Allow providing custom id generator?
    return nodes.PromoteOffsetsNode(node, ids.ColumnId(guids.generate_guid()))
| 458 | + |
| 459 | + |
def drop_cols(
    node: nodes.BigFrameNode, drop_cols: Tuple[ids.ColumnId, ...]
) -> nodes.SelectionNode:
    """Select every column of ``node`` except those listed in ``drop_cols``.

    Each kept column is selected under its existing id.
    """
    # adding a whole node that redefines the schema is a lot of overhead, should do something more efficient
    kept_ids = [col_id for col_id in node.ids if col_id not in drop_cols]
    selections = tuple(
        (scalar_exprs.DerefOp(col_id), col_id) for col_id in kept_ids
    )
    return nodes.SelectionNode(node, selections)
| 468 | + |
| 469 | + |
def convert_simple_slice(
    offsets: scalar_exprs.Expression,
    start: int = 0,
    stop: Optional[int] = None,
    step: int = 1,
) -> scalar_exprs.Expression:
    """Build a forward-offsets predicate for a slice with a positive step.

    Args:
        offsets: Expression yielding each row's forward offset.
        start: Non-negative first offset to keep.
        stop: Non-negative exclusive upper bound, or ``None`` for unbounded.
        step: Stride; a stride condition is only emitted when ``step > 1``.

    Returns:
        The merged predicate, or a constant ``True`` expression when no
        condition is needed.
    """
    assert start >= 0
    assert (stop is None) or (stop >= 0)

    predicates = []

    # Lower bound: offsets >= start (skipped when start is 0).
    if start > 0:
        lower_bound = ops.ge_op.as_expr(offsets, scalar_exprs.const(start))
        predicates.append(lower_bound)

    # Upper bound: offsets < stop.
    if (stop is not None) and (stop >= 0):
        upper_bound = ops.lt_op.as_expr(offsets, scalar_exprs.const(stop))
        predicates.append(upper_bound)

    # Stride: (offsets - start) % step == 0.
    if step > 1:
        distance = ops.sub_op.as_expr(offsets, scalar_exprs.const(start))
        remainder = ops.mod_op.as_expr(distance, scalar_exprs.const(step))
        predicates.append(ops.eq_op.as_expr(remainder, scalar_exprs.const(0)))

    merged = merge_predicates(predicates)
    return merged if merged else scalar_exprs.const(True)
| 494 | + |
| 495 | + |
def convert_complex_slice(
    forward_offsets: scalar_exprs.Expression,
    reverse_offsets: scalar_exprs.Expression,
    start: int,
    stop: Optional[int],
    step: int = 1,
) -> scalar_exprs.Expression:
    """Build a slice predicate from forward and reverse offsets.

    Non-negative ``start``/``stop`` are compared against ``forward_offsets``;
    negative ones are compared against ``reverse_offsets``, where index
    ``-k`` corresponds to reverse offset ``k - 1``.

    Args:
        forward_offsets: Expression yielding each row's offset from the start.
        reverse_offsets: Expression yielding each row's offset from the end.
        start: Index of the first selected row.
            NOTE(review): the stride branch compares ``start`` with 0, so
            callers are expected to pass a concrete int — confirm.
        stop: Exclusive bound, or ``None`` for unbounded.
        step: Non-zero stride; negative when walking backwards.

    Returns:
        The merged predicate, or a constant ``True`` expression when no
        condition is needed.
    """
    conditions = []
    assert step != 0
    if start or ((start is not None) and step < 0):
        if start > 0 and step > 0:
            start_cond = ops.ge_op.as_expr(forward_offsets, scalar_exprs.const(start))
        elif start >= 0 and step < 0:
            # Includes start == 0: walking backwards from row 0 keeps only
            # rows at forward offset <= 0.
            start_cond = ops.le_op.as_expr(forward_offsets, scalar_exprs.const(start))
        elif start < 0 and step > 0:
            start_cond = ops.le_op.as_expr(
                reverse_offsets, scalar_exprs.const(-start - 1)
            )
        else:
            assert start < 0 and step < 0
            start_cond = ops.ge_op.as_expr(
                reverse_offsets, scalar_exprs.const(-start - 1)
            )
        conditions.append(start_cond)
    if stop is not None:
        if stop >= 0 and step > 0:
            stop_cond = ops.lt_op.as_expr(forward_offsets, scalar_exprs.const(stop))
        elif stop >= 0 and step < 0:
            stop_cond = ops.gt_op.as_expr(forward_offsets, scalar_exprs.const(stop))
        elif stop < 0 and step > 0:
            stop_cond = ops.gt_op.as_expr(
                reverse_offsets, scalar_exprs.const(-stop - 1)
            )
        else:
            assert (stop < 0) and (step < 0)
            stop_cond = ops.lt_op.as_expr(
                reverse_offsets, scalar_exprs.const(-stop - 1)
            )
        conditions.append(stop_cond)
    if step != 1:
        # A row is selected when its distance from the first selected row is
        # a multiple of step. The anchor is the same offset the start bound
        # uses, so the first row always has distance 0. Divisibility does not
        # depend on the sign of either operand.
        if start >= 0:
            start_diff = ops.sub_op.as_expr(forward_offsets, scalar_exprs.const(start))
        else:
            start_diff = ops.sub_op.as_expr(
                reverse_offsets, scalar_exprs.const(-start - 1)
            )
        step_cond = ops.eq_op.as_expr(
            ops.mod_op.as_expr(start_diff, scalar_exprs.const(step)),
            scalar_exprs.const(0),
        )
        conditions.append(step_cond)
    return merge_predicates(conditions) or scalar_exprs.const(True)
0 commit comments