Skip to content

Commit 6629e66

Browse files
authored
chore: add dry_run parameter to _read_gbq_colab (#1721)
1 parent c3c830c commit 6629e66

File tree

5 files changed

+83 additions, -51 deletions

bigframes/session/__init__.py

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -477,14 +477,34 @@ def _register_object(
477477
):
478478
self._objects.append(weakref.ref(object))
479479

480+
@overload
480481
def _read_gbq_colab(
481482
self,
482483
query: str,
483-
# TODO: Add a callback parameter that takes some kind of Event object.
484-
# TODO: Add dry_run parameter.
485484
*,
486485
pyformat_args: Optional[Dict[str, Any]] = None,
486+
dry_run: Literal[False] = ...,
487487
) -> dataframe.DataFrame:
488+
...
489+
490+
@overload
491+
def _read_gbq_colab(
492+
self,
493+
query: str,
494+
*,
495+
pyformat_args: Optional[Dict[str, Any]] = None,
496+
dry_run: Literal[True] = ...,
497+
) -> pandas.Series:
498+
...
499+
500+
def _read_gbq_colab(
501+
self,
502+
query: str,
503+
# TODO: Add a callback parameter that takes some kind of Event object.
504+
*,
505+
pyformat_args: Optional[Dict[str, Any]] = None,
506+
dry_run: bool = False,
507+
) -> Union[dataframe.DataFrame, pandas.Series]:
488508
"""A version of read_gbq that has the necessary default values for use in colab integrations.
489509
490510
This includes, no ordering, no index, no progress bar, always use string
@@ -501,23 +521,21 @@ def _read_gbq_colab(
501521
None, this function always assumes {var} refers to a variable
502522
that is supposed to be supplied in this dictionary.
503523
"""
504-
# TODO: Allow for a table ID to avoid queries like with read_gbq?
505-
506524
if pyformat_args is None:
507525
pyformat_args = {}
508526

509-
# TODO: move this to read_gbq_query if/when we expose this feature
510-
# beyond in _read_gbq_colab.
511527
query = bigframes.core.pyformat.pyformat(
512528
query,
513529
pyformat_args=pyformat_args,
530+
# TODO: add dry_run parameter to avoid API calls for data in pyformat_args
514531
)
515532

516533
return self._loader.read_gbq_query(
517534
query=query,
518535
index_col=bigframes.enums.DefaultIndexKind.NULL,
519536
api_name="read_gbq_colab",
520537
force_total_order=False,
538+
dry_run=typing.cast(Union[Literal[False], Literal[True]], dry_run),
521539
)
522540

523541
@overload

bigframes/session/dry_runs.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -101,34 +101,38 @@ def get_query_stats(
101101

102102
job_api_repr = copy.deepcopy(query_job._properties)
103103

104-
job_ref = job_api_repr["jobReference"]
104+
# jobReference might not be populated for "job optional" queries.
105+
job_ref = job_api_repr.get("jobReference", {})
105106
for key, val in job_ref.items():
106107
index.append(key)
107108
values.append(val)
108109

110+
configuration = job_api_repr.get("configuration", {})
109111
index.append("jobType")
110-
values.append(job_api_repr["configuration"]["jobType"])
112+
values.append(configuration.get("jobType", None))
111113

112-
query_config = job_api_repr["configuration"]["query"]
114+
query_config = configuration.get("query", {})
113115
for key in ("destinationTable", "useLegacySql"):
114116
index.append(key)
115-
values.append(query_config.get(key))
117+
values.append(query_config.get(key, None))
116118

117-
query_stats = job_api_repr["statistics"]["query"]
119+
statistics = job_api_repr.get("statistics", {})
120+
query_stats = statistics.get("query", {})
118121
for key in (
119122
"referencedTables",
120123
"totalBytesProcessed",
121124
"cacheHit",
122125
"statementType",
123126
):
124127
index.append(key)
125-
values.append(query_stats.get(key))
128+
values.append(query_stats.get(key, None))
126129

130+
creation_time = statistics.get("creationTime", None)
127131
index.append("creationTime")
128132
values.append(
129-
pandas.Timestamp(
130-
job_api_repr["statistics"]["creationTime"], unit="ms", tz="UTC"
131-
)
133+
pandas.Timestamp(creation_time, unit="ms", tz="UTC")
134+
if creation_time is not None
135+
else None
132136
)
133137

134138
return pandas.Series(values, index=index)

bigframes/testing/mocks.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import copy
1516
import datetime
1617
from typing import Optional, Sequence
1718
import unittest.mock as mock
@@ -78,11 +79,14 @@ def create_bigquery_session(
7879
type(table).num_rows = mock.PropertyMock(return_value=1000000000)
7980
bqclient.get_table.return_value = table
8081

82+
queries = []
8183
job_configs = []
8284

8385
def query_mock(query, *args, job_config=None, **kwargs):
84-
job_configs.append(job_config)
86+
queries.append(query)
87+
job_configs.append(copy.deepcopy(job_config))
8588
query_job = mock.create_autospec(google.cloud.bigquery.QueryJob)
89+
query_job._properties = {}
8690
type(query_job).destination = mock.PropertyMock(
8791
return_value=anonymous_dataset.table("test_table"),
8892
)
@@ -100,7 +104,8 @@ def query_mock(query, *args, job_config=None, **kwargs):
100104
existing_query_and_wait = bqclient.query_and_wait
101105

102106
def query_and_wait_mock(query, *args, job_config=None, **kwargs):
103-
job_configs.append(job_config)
107+
queries.append(query)
108+
job_configs.append(copy.deepcopy(job_config))
104109
if query.startswith("SELECT CURRENT_TIMESTAMP()"):
105110
return iter([[datetime.datetime.now()]])
106111
else:
@@ -118,6 +123,7 @@ def query_and_wait_mock(query, *args, job_config=None, **kwargs):
118123
session._bq_connection_manager = mock.create_autospec(
119124
bigframes.clients.BqConnectionManager, instance=True
120125
)
126+
session._queries = queries # type: ignore
121127
session._job_configs = job_configs # type: ignore
122128
return session
123129

tests/system/small/session/test_read_gbq_colab.py

Lines changed: 2 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def test_read_gbq_colab_to_pandas_batches_preserves_order_by(maybe_ordered_sessi
4747
def test_read_gbq_colab_includes_formatted_scalars(session):
4848
pyformat_args = {
4949
"some_integer": 123,
50-
"some_string": "This could be dangerous, but we esape it",
50+
"some_string": "This could be dangerous, but we escape it",
5151
# This is not a supported type, but ignored if not referenced.
5252
"some_object": object(),
5353
}
@@ -66,39 +66,7 @@ def test_read_gbq_colab_includes_formatted_scalars(session):
6666
{
6767
"some_integer": pandas.Series([123], dtype=pandas.Int64Dtype()),
6868
"some_string": pandas.Series(
69-
["This could be dangerous, but we esape it"],
70-
dtype="string[pyarrow]",
71-
),
72-
"escaped": pandas.Series(["{escaped}"], dtype="string[pyarrow]"),
73-
}
74-
),
75-
)
76-
77-
78-
def test_read_gbq_colab_includes_formatted_bigframes_dataframe(session):
79-
pyformat_args = {
80-
# TODO: put a bigframes DataFrame here.
81-
"some_integer": 123,
82-
"some_string": "This could be dangerous, but we esape it",
83-
# This is not a supported type, but ignored if not referenced.
84-
"some_object": object(),
85-
}
86-
df = session._read_gbq_colab(
87-
"""
88-
SELECT {some_integer} as some_integer,
89-
{some_string} as some_string,
90-
'{{escaped}}' as escaped
91-
""",
92-
pyformat_args=pyformat_args,
93-
)
94-
result = df.to_pandas()
95-
pandas.testing.assert_frame_equal(
96-
result,
97-
pandas.DataFrame(
98-
{
99-
"some_integer": pandas.Series([123], dtype=pandas.Int64Dtype()),
100-
"some_string": pandas.Series(
101-
["This could be dangerous, but we esape it"],
69+
["This could be dangerous, but we escape it"],
10270
dtype="string[pyarrow]",
10371
),
10472
"escaped": pandas.Series(["{escaped}"], dtype="string[pyarrow]"),

tests/unit/session/test_read_gbq_colab.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,3 +30,39 @@ def test_read_gbq_colab_includes_label():
3030
label_values.extend(config.labels.values())
3131

3232
assert "read_gbq_colab" in label_values
33+
34+
35+
def test_read_gbq_colab_includes_formatted_values_in_dry_run():
36+
session = mocks.create_bigquery_session()
37+
38+
pyformat_args = {
39+
"some_integer": 123,
40+
"some_string": "This could be dangerous, but we escape it",
41+
# This is not a supported type, but ignored if not referenced.
42+
"some_object": object(),
43+
}
44+
_ = session._read_gbq_colab(
45+
"""
46+
SELECT {some_integer} as some_integer,
47+
{some_string} as some_string,
48+
'{{escaped}}' as escaped
49+
""",
50+
pyformat_args=pyformat_args,
51+
dry_run=True,
52+
)
53+
expected = """
54+
SELECT 123 as some_integer,
55+
'This could be dangerous, but we escape it' as some_string,
56+
'{escaped}' as escaped
57+
"""
58+
queries = session._queries # type: ignore
59+
configs = session._job_configs # type: ignore
60+
61+
for query, config in zip(queries, configs):
62+
if config is None:
63+
continue
64+
if config.dry_run:
65+
break
66+
67+
assert config.dry_run
68+
assert query.strip() == expected.strip()

0 commit comments

Comments
 (0)