From 0b4c2d7ac916e1a39d812fa76d68ae73fcbe4d94 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 28 Aug 2025 17:48:35 +0000 Subject: [PATCH] test: add unit test for remap_variables --- .../{test_rewrite.py => rewrite/conftest.py} | 36 ++--- tests/unit/core/rewrite/test_identifiers.py | 132 ++++++++++++++++++ tests/unit/core/rewrite/test_slices.py | 34 +++++ 3 files changed, 181 insertions(+), 21 deletions(-) rename tests/unit/core/{test_rewrite.py => rewrite/conftest.py} (56%) create mode 100644 tests/unit/core/rewrite/test_identifiers.py create mode 100644 tests/unit/core/rewrite/test_slices.py diff --git a/tests/unit/core/test_rewrite.py b/tests/unit/core/rewrite/conftest.py similarity index 56% rename from tests/unit/core/test_rewrite.py rename to tests/unit/core/rewrite/conftest.py index 1f1a2c3db9..22b897f3bf 100644 --- a/tests/unit/core/test_rewrite.py +++ b/tests/unit/core/rewrite/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,10 +14,9 @@ import unittest.mock as mock import google.cloud.bigquery +import pytest import bigframes.core as core -import bigframes.core.nodes as nodes -import bigframes.core.rewrite.slices import bigframes.core.schema TABLE_REF = google.cloud.bigquery.TableReference.from_string("project.dataset.table") @@ -31,27 +30,22 @@ ) FAKE_SESSION = mock.create_autospec(bigframes.Session, instance=True) type(FAKE_SESSION)._strictly_ordered = mock.PropertyMock(return_value=True) -LEAF = core.ArrayValue.from_table( - session=FAKE_SESSION, - table=TABLE, - schema=bigframes.core.schema.ArraySchema.from_bq_table(TABLE), -).node -def test_rewrite_noop_slice(): - slice = nodes.SliceNode(LEAF, None, None) - result = bigframes.core.rewrite.slices.rewrite_slice(slice) - assert result == LEAF +@pytest.fixture +def table(): + return TABLE -def test_rewrite_reverse_slice(): - slice = nodes.SliceNode(LEAF, None, None, -1) - result = bigframes.core.rewrite.slices.rewrite_slice(slice) - assert result == nodes.ReversedNode(LEAF) +@pytest.fixture +def fake_session(): + return FAKE_SESSION -def test_rewrite_filter_slice(): - slice = nodes.SliceNode(LEAF, None, 2) - result = bigframes.core.rewrite.slices.rewrite_slice(slice) - assert list(result.fields) == list(LEAF.fields) - assert isinstance(result.child, nodes.FilterNode) +@pytest.fixture +def leaf(fake_session, table): + return core.ArrayValue.from_table( + session=fake_session, + table=table, + schema=bigframes.core.schema.ArraySchema.from_bq_table(table), + ).node diff --git a/tests/unit/core/rewrite/test_identifiers.py b/tests/unit/core/rewrite/test_identifiers.py new file mode 100644 index 0000000000..fd12df60a8 --- /dev/null +++ b/tests/unit/core/rewrite/test_identifiers.py @@ -0,0 +1,132 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import bigframes.core as core +import bigframes.core.identifiers as identifiers +import bigframes.core.nodes as nodes +import bigframes.core.rewrite.identifiers as id_rewrite + + +def test_remap_variables_single_node(leaf): + node = leaf + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node, mapping = id_rewrite.remap_variables(node, id_generator) + assert new_node is not node + assert len(mapping) == 2 + assert set(mapping.keys()) == {f.id for f in node.fields} + assert set(mapping.values()) == { + identifiers.ColumnId("id_0"), + identifiers.ColumnId("id_1"), + } + + +def test_remap_variables_projection(leaf): + node = nodes.ProjectionNode( + leaf, + ( + ( + core.expression.DerefOp(leaf.fields[0].id), + identifiers.ColumnId("new_col"), + ), + ), + ) + id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node, mapping = id_rewrite.remap_variables(node, id_generator) + assert new_node is not node + assert len(mapping) == 3 + assert set(mapping.keys()) == {f.id for f in node.fields} + assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)} + + +def test_remap_variables_nested_join_stability(leaf, fake_session, table): + # Create two more distinct leaf nodes + leaf2_uncached = core.ArrayValue.from_table( + session=fake_session, + table=table, + schema=leaf.schema, + ).node + leaf2 = leaf2_uncached.remap_vars( + { + field.id: identifiers.ColumnId(f"leaf2_{field.id.name}") + for field in leaf2_uncached.fields + } + ) + leaf3_uncached = core.ArrayValue.from_table( + session=fake_session, + table=table, + schema=leaf.schema, + ).node + leaf3 = leaf3_uncached.remap_vars( + { + field.id: identifiers.ColumnId(f"leaf3_{field.id.name}") + for field in leaf3_uncached.fields + } + ) + + # Create a nested join: (leaf JOIN leaf2) JOIN leaf3 + inner_join = nodes.JoinNode( + left_child=leaf, + right_child=leaf2, + conditions=( + ( + core.expression.DerefOp(leaf.fields[0].id), + core.expression.DerefOp(leaf2.fields[0].id), + ), + ), + type="inner", + propogate_order=False, + ) + outer_join = nodes.JoinNode( + left_child=inner_join, + right_child=leaf3, + conditions=( + ( + core.expression.DerefOp(inner_join.fields[0].id), + core.expression.DerefOp(leaf3.fields[0].id), + ), + ), + type="inner", + propogate_order=False, + ) + + # Run remap_variables twice and assert stability + id_generator1 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node1, mapping1 = id_rewrite.remap_variables(outer_join, id_generator1) + + id_generator2 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node2, mapping2 = id_rewrite.remap_variables(outer_join, id_generator2) + + assert new_node1 == new_node2 + assert mapping1 == mapping2 + + +def test_remap_variables_concat_self_stability(leaf): + # Create a concat node with the same child twice + node = nodes.ConcatNode( + children=(leaf, leaf), + output_ids=( + identifiers.ColumnId("concat_a"), + identifiers.ColumnId("concat_b"), + ), + ) + + # Run remap_variables twice and assert stability + id_generator1 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node1, mapping1 = id_rewrite.remap_variables(node, id_generator1) + + id_generator2 = (identifiers.ColumnId(f"id_{i}") for i in range(100)) + new_node2, mapping2 = id_rewrite.remap_variables(node, id_generator2) + + assert new_node1 == new_node2 + assert mapping1 == mapping2 diff --git a/tests/unit/core/rewrite/test_slices.py b/tests/unit/core/rewrite/test_slices.py new file mode 100644 index 0000000000..6d49ffb80a --- /dev/null +++ b/tests/unit/core/rewrite/test_slices.py @@ -0,0 +1,34 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import bigframes.core.nodes as nodes +import bigframes.core.rewrite.slices + + +def test_rewrite_noop_slice(leaf): + slice = nodes.SliceNode(leaf, None, None) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) + assert result == leaf + + +def test_rewrite_reverse_slice(leaf): + slice = nodes.SliceNode(leaf, None, None, -1) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) + assert result == nodes.ReversedNode(leaf) + + +def test_rewrite_filter_slice(leaf): + slice = nodes.SliceNode(leaf, None, 2) + result = bigframes.core.rewrite.slices.rewrite_slice(slice) + assert list(result.fields) == list(leaf.fields) + assert isinstance(result.child, nodes.FilterNode)