Skip to content

Commit 0b4c2d7

Browse files
committed
test: add unit test for remap_variables
1 parent 209d0d4 commit 0b4c2d7

File tree

3 files changed

+181
-21
lines changed

3 files changed

+181
-21
lines changed
Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Google LLC
1+
# Copyright 2025 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,10 +14,9 @@
1414
import unittest.mock as mock
1515

1616
import google.cloud.bigquery
17+
import pytest
1718

1819
import bigframes.core as core
19-
import bigframes.core.nodes as nodes
20-
import bigframes.core.rewrite.slices
2120
import bigframes.core.schema
2221

2322
TABLE_REF = google.cloud.bigquery.TableReference.from_string("project.dataset.table")
@@ -31,27 +30,22 @@
3130
)
3231
FAKE_SESSION = mock.create_autospec(bigframes.Session, instance=True)
3332
type(FAKE_SESSION)._strictly_ordered = mock.PropertyMock(return_value=True)
34-
LEAF = core.ArrayValue.from_table(
35-
session=FAKE_SESSION,
36-
table=TABLE,
37-
schema=bigframes.core.schema.ArraySchema.from_bq_table(TABLE),
38-
).node
3933

4034

41-
def test_rewrite_noop_slice():
42-
slice = nodes.SliceNode(LEAF, None, None)
43-
result = bigframes.core.rewrite.slices.rewrite_slice(slice)
44-
assert result == LEAF
35+
@pytest.fixture
36+
def table():
37+
return TABLE
4538

4639

47-
def test_rewrite_reverse_slice():
48-
slice = nodes.SliceNode(LEAF, None, None, -1)
49-
result = bigframes.core.rewrite.slices.rewrite_slice(slice)
50-
assert result == nodes.ReversedNode(LEAF)
40+
@pytest.fixture
41+
def fake_session():
42+
return FAKE_SESSION
5143

5244

53-
def test_rewrite_filter_slice():
54-
slice = nodes.SliceNode(LEAF, None, 2)
55-
result = bigframes.core.rewrite.slices.rewrite_slice(slice)
56-
assert list(result.fields) == list(LEAF.fields)
57-
assert isinstance(result.child, nodes.FilterNode)
45+
@pytest.fixture
46+
def leaf(fake_session, table):
47+
return core.ArrayValue.from_table(
48+
session=fake_session,
49+
table=table,
50+
schema=bigframes.core.schema.ArraySchema.from_bq_table(table),
51+
).node
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import bigframes.core as core
16+
import bigframes.core.identifiers as identifiers
17+
import bigframes.core.nodes as nodes
18+
import bigframes.core.rewrite.identifiers as id_rewrite
19+
20+
21+
def test_remap_variables_single_node(leaf):
22+
node = leaf
23+
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
24+
new_node, mapping = id_rewrite.remap_variables(node, id_generator)
25+
assert new_node is not node
26+
assert len(mapping) == 2
27+
assert set(mapping.keys()) == {f.id for f in node.fields}
28+
assert set(mapping.values()) == {
29+
identifiers.ColumnId("id_0"),
30+
identifiers.ColumnId("id_1"),
31+
}
32+
33+
34+
def test_remap_variables_projection(leaf):
35+
node = nodes.ProjectionNode(
36+
leaf,
37+
(
38+
(
39+
core.expression.DerefOp(leaf.fields[0].id),
40+
identifiers.ColumnId("new_col"),
41+
),
42+
),
43+
)
44+
id_generator = (identifiers.ColumnId(f"id_{i}") for i in range(100))
45+
new_node, mapping = id_rewrite.remap_variables(node, id_generator)
46+
assert new_node is not node
47+
assert len(mapping) == 3
48+
assert set(mapping.keys()) == {f.id for f in node.fields}
49+
assert set(mapping.values()) == {identifiers.ColumnId(f"id_{i}") for i in range(3)}
50+
51+
52+
def test_remap_variables_nested_join_stability(leaf, fake_session, table):
53+
# Create two more distinct leaf nodes
54+
leaf2_uncached = core.ArrayValue.from_table(
55+
session=fake_session,
56+
table=table,
57+
schema=leaf.schema,
58+
).node
59+
leaf2 = leaf2_uncached.remap_vars(
60+
{
61+
field.id: identifiers.ColumnId(f"leaf2_{field.id.name}")
62+
for field in leaf2_uncached.fields
63+
}
64+
)
65+
leaf3_uncached = core.ArrayValue.from_table(
66+
session=fake_session,
67+
table=table,
68+
schema=leaf.schema,
69+
).node
70+
leaf3 = leaf3_uncached.remap_vars(
71+
{
72+
field.id: identifiers.ColumnId(f"leaf3_{field.id.name}")
73+
for field in leaf3_uncached.fields
74+
}
75+
)
76+
77+
# Create a nested join: (leaf JOIN leaf2) JOIN leaf3
78+
inner_join = nodes.JoinNode(
79+
left_child=leaf,
80+
right_child=leaf2,
81+
conditions=(
82+
(
83+
core.expression.DerefOp(leaf.fields[0].id),
84+
core.expression.DerefOp(leaf2.fields[0].id),
85+
),
86+
),
87+
type="inner",
88+
propogate_order=False,
89+
)
90+
outer_join = nodes.JoinNode(
91+
left_child=inner_join,
92+
right_child=leaf3,
93+
conditions=(
94+
(
95+
core.expression.DerefOp(inner_join.fields[0].id),
96+
core.expression.DerefOp(leaf3.fields[0].id),
97+
),
98+
),
99+
type="inner",
100+
propogate_order=False,
101+
)
102+
103+
# Run remap_variables twice and assert stability
104+
id_generator1 = (identifiers.ColumnId(f"id_{i}") for i in range(100))
105+
new_node1, mapping1 = id_rewrite.remap_variables(outer_join, id_generator1)
106+
107+
id_generator2 = (identifiers.ColumnId(f"id_{i}") for i in range(100))
108+
new_node2, mapping2 = id_rewrite.remap_variables(outer_join, id_generator2)
109+
110+
assert new_node1 == new_node2
111+
assert mapping1 == mapping2
112+
113+
114+
def test_remap_variables_concat_self_stability(leaf):
115+
# Create a concat node with the same child twice
116+
node = nodes.ConcatNode(
117+
children=(leaf, leaf),
118+
output_ids=(
119+
identifiers.ColumnId("concat_a"),
120+
identifiers.ColumnId("concat_b"),
121+
),
122+
)
123+
124+
# Run remap_variables twice and assert stability
125+
id_generator1 = (identifiers.ColumnId(f"id_{i}") for i in range(100))
126+
new_node1, mapping1 = id_rewrite.remap_variables(node, id_generator1)
127+
128+
id_generator2 = (identifiers.ColumnId(f"id_{i}") for i in range(100))
129+
new_node2, mapping2 = id_rewrite.remap_variables(node, id_generator2)
130+
131+
assert new_node1 == new_node2
132+
assert mapping1 == mapping2
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import bigframes.core.nodes as nodes
15+
import bigframes.core.rewrite.slices
16+
17+
18+
def test_rewrite_noop_slice(leaf):
19+
slice = nodes.SliceNode(leaf, None, None)
20+
result = bigframes.core.rewrite.slices.rewrite_slice(slice)
21+
assert result == leaf
22+
23+
24+
def test_rewrite_reverse_slice(leaf):
25+
slice = nodes.SliceNode(leaf, None, None, -1)
26+
result = bigframes.core.rewrite.slices.rewrite_slice(slice)
27+
assert result == nodes.ReversedNode(leaf)
28+
29+
30+
def test_rewrite_filter_slice(leaf):
31+
slice = nodes.SliceNode(leaf, None, 2)
32+
result = bigframes.core.rewrite.slices.rewrite_slice(slice)
33+
assert list(result.fields) == list(leaf.fields)
34+
assert isinstance(result.child, nodes.FilterNode)

0 commit comments

Comments
 (0)