Skip to content

Commit 8e0001c

Browse files
tcyameterstick-copybara
authored andcommitted
Allow to CREATE TEMP TABLE in SQL generation when necessary. Previously we assume RAND() in the WITH clause behave as if they are evaluated only once, but that's not always the case. In situation when that's not true, we need to CREATE TEMP TABLE to materialize the subqueries that have volatile functions, so that the same result is used in all places.
PiperOrigin-RevId: 772345114
1 parent 9b75153 commit 8e0001c

File tree

5 files changed

+196
-14
lines changed

5 files changed

+196
-14
lines changed

meterstick_demo.ipynb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18589,14 +18589,22 @@
1858918589
"source": [
1859018590
"#SQL\n",
1859118591
"\n",
18592-
"You can easily get SQL query for all built-in Metrics and Operations by calling `to_sql(sql_table, split_by)`.\n",
18592+
"You can easily get SQL query for all built-in Metrics and Operations by calling `to_sql(sql_table, split_by, create_tmp_table_for_volatile_fn=None)`.\n",
1859318593
"\n",
1859418594
"You can also directly execute the query by calling `compute_on_sql(sql_table, split_by, execute, melted)`,\n",
1859518595
"\n",
1859618596
"where `execute` is a function that can execute SQL queries. The return is very similar to compute_on().\n",
1859718597
"\n",
1859818598
"The dialect it uses is the [standard SQL](https://cloud.google.com/bigquery/docs/reference/standard-sql) in Google Cloud's BigQuery.\n",
1859918599
"\n",
18600+
"The choice of `create_tmp_table_for_volatile_fn` depends on your SQL engine. If query\n",
18601+
"```\n",
18602+
"WITH T AS (SELECT RAND() AS r)\n",
18603+
"SELECT t1.r - t2.r AS d\n",
18604+
"FROM T t1 CROSS JOIN T t2\n",
18605+
"```\n",
18606+
"does NOT always return 0 on your engine, set `create_tmp_table_for_volatile_fn` to `True`.\n",
18607+
"\n",
1860018608
"Additionally, `compute_on_sql` also takes a `mode` arg. It can be `None` (default and recommended), `'mixed'` or `'magic'`. The mode controls how we split the computation between SQL and Python. For example, for a Metric with descendants, we can compute everything in SQL (if applicable), or the children in SQL and the parent in Python, or grandchildren in SQL and the rest in Python. The default `None` mode maximizes the SQL usage, namely, everything can be computed in SQL is computed in SQL. The `mixed` mode does the opposite. It minimizes the SQL usage, namely, only leaf Metrics are computed in SQL. The advantage of the `sql` mode is that SQL is usually faster and can handle larger data than Python. On the other hand, as all the `Metric`s computed in Python will be cached, the `mixed` mode will cache all levels of `Metric`s in the `Metric` tree. As a result, if you have a complex `Metric` that has many duplicated leaf `Metric`s, the `mixed` mode could be faster.\n",
1860118609
"\n",
1860218610
"There is another `magic` mode that only applies to `Model`s. The mode computes sufficient statistics in SQL then use them to solve the coefficients in Python. It's faster then the regular modes when fitting `Model`s on large data.\n",

metrics.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ def compute_on_beam(
9393
# pylint: enable=g-long-lambda
9494

9595

96-
def to_sql(table, split_by=None):
97-
return lambda metric: metric.to_sql(table, split_by)
96+
def to_sql(table, split_by=None, create_tmp_table_for_volatile_fn=None):
97+
return lambda metric: metric.to_sql(
98+
table, split_by, create_tmp_table_for_volatile_fn
99+
)
98100

99101

100102
# Classes we built so caching across instances can be enabled with confidence.
@@ -677,7 +679,25 @@ def to_series_or_number(self, df):
677679

678680
def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
679681
"""Executes the query from to_sql() and process the result."""
680-
query = self.to_sql(table, split_by)
682+
query = self.to_sql(table, split_by, False)
683+
# We try to avoid using CREATE TEMP TABLE when possible. It's only used when
684+
# - the query contains RAND();
685+
# - the execute doesn't evaluate RAND() only once in the WITH clause;
686+
# - ALLOW_TEMP_TABLE is True.
687+
if sql.ALLOW_TEMP_TABLE and 'RAND()' in str(query):
688+
query_with_tmp_table = self.to_sql(table, split_by, True)
689+
if str(query) != str(
690+
query_with_tmp_table
691+
) and not sql.rand_run_only_once_in_with_clause(execute):
692+
try:
693+
execute('CREATE OR REPLACE TEMP TABLE T AS (SELECT 42 AS ans);')
694+
sql.TEMP_TABLE_SUPPORTED = True
695+
query = self.to_sql(table, split_by, True)
696+
except Exception as exc:
697+
sql.TEMP_TABLE_SUPPORTED = False
698+
raise NotImplementedError from exc
699+
finally:
700+
sql.TEMP_TABLE_SUPPORTED = None
681701
res = execute(str(query))
682702
extra_idx = list(utils.get_extra_idx(self, return_superset=True))
683703
indexes = split_by + extra_idx if split_by else extra_idx
@@ -690,8 +710,39 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
690710
res.sort_values(split_by, kind='mergesort', inplace=True)
691711
return res
692712

693-
def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
694-
"""Generates SQL query for the metric."""
713+
def to_sql(
714+
self,
715+
table,
716+
split_by: Optional[Union[Text, List[Text]]] = None,
717+
create_tmp_table_for_volatile_fn=None,
718+
):
719+
"""Generates SQL query for the metric.
720+
721+
Args:
722+
table: The table or subquery we want to query from.
723+
split_by: The columns that we use to split the data.
724+
create_tmp_table_for_volatile_fn: When generating the query, we assume
725+
that volatile functions like RAND() in the WITH clause behave as if they
726+
are evaluated only once. Unfortunately, not all engines behave like
727+
that. In those cases, we need to CREATE TEMP TABLE to materialize the
728+
subqueries that have volatile functions, so that the same result is used
729+
in all places. An example is
730+
WITH T AS (SELECT RAND() AS r)
731+
SELECT t1.r - t2.r AS d
732+
FROM T t1 CROSS JOIN T t2.
733+
If it doesn't always evaluates to 0, then this arg should be True, and
734+
we will put all subqueries that
735+
1) have volatile functions and
736+
2) are referenced in the same query multiple times,
737+
into CREATE TEMP TABLE statements.
738+
Note that this arg has no effect if sql.ALLOW_TEMP_TABLE is False.
739+
When you use compute_on_sql or compute_on_beam, this arg is
740+
automatically decided based on your `execute` function.
741+
742+
Returns:
743+
The SQL query for the metric as a SQL instance, which is similar to a str.
744+
Calling str() on it will get the query in string.
745+
"""
695746
global_filter = utils.get_global_filter(self)
696747
indexes = sql.Columns(split_by).add(
697748
utils.get_extra_idx(self, return_superset=True)
@@ -708,6 +759,17 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
708759
global_filter, indexes,
709760
sql.Filters(), with_data)
710761
query.with_data = with_data
762+
create_tmp_table = (
763+
sql.ALLOW_TEMP_TABLE
764+
if create_tmp_table_for_volatile_fn is None
765+
else create_tmp_table_for_volatile_fn
766+
)
767+
if not create_tmp_table:
768+
return query
769+
# None means we don't know yet so we only check for False.
770+
if sql.TEMP_TABLE_SUPPORTED is False: # pylint: disable=g-bool-id-comparison
771+
raise NotImplementedError # to fall back to the mixed mode
772+
with_data.temp_tables = sql.get_temp_tables(with_data)
711773
return query
712774

713775
def get_sql_and_with_clause(self, table: sql.Datasource,

operations.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,23 +2277,27 @@ def compute_children_sql(self,
22772277
"""The return should be similar to compute_children()."""
22782278
raise NotImplementedError
22792279

2280-
def to_sql(self, table, split_by=None):
2280+
def to_sql(self, table, split_by=None, create_tmp_table_for_volatile_fn=None):
22812281
if not isinstance(self, (Jackknife, Bootstrap)):
22822282
raise NotImplementedError
22832283
split_by = [split_by] if isinstance(split_by, str) else list(split_by or [])
22842284
# If self is not root, this function won't be called.
22852285
self._is_root_node = True
22862286
if self.has_been_preaggregated or not self.can_precompute():
22872287
if not self.where:
2288-
return super(MetricWithCI, self).to_sql(table, split_by)
2288+
return super(MetricWithCI, self).to_sql(
2289+
table, split_by, create_tmp_table_for_volatile_fn
2290+
)
22892291
table = sql.Sql(None, table, self.where)
22902292
self_no_filter = copy.deepcopy(self)
22912293
self_no_filter.where = None
2292-
return self_no_filter.to_sql(table, split_by)
2294+
return self_no_filter.to_sql(
2295+
table, split_by, create_tmp_table_for_volatile_fn
2296+
)
22932297

22942298
expanded, _ = utils.get_fully_expanded_equivalent_metric_tree(self)
22952299
if self != expanded:
2296-
return expanded.to_sql(table, split_by)
2300+
return expanded.to_sql(table, split_by, create_tmp_table_for_volatile_fn)
22972301

22982302
expanded.where = None # The filter has been taken care of in preaggregation
22992303
expanded = utils.push_filters_to_leaf(expanded)
@@ -2322,7 +2326,7 @@ def to_sql(self, table, split_by=None):
23222326
equiv.unit = None
23232327
else:
23242328
equiv.has_local_filter = any([l.where for l in leaf])
2325-
return equiv.to_sql(preagg, split_by)
2329+
return equiv.to_sql(preagg, split_by, create_tmp_table_for_volatile_fn)
23262330

23272331
def get_sql_and_with_clause(
23282332
self, table, split_by, global_filter, indexes, local_filter, with_data

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "meterstick"
3-
version = "1.5.3"
3+
version = "1.5.4"
44
authors = [
55
{ name="Xunmo Yang", email="xunmo@google.com" },
66
{ name="Dennis Sun", email="dlsun@google.com" },

sql.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525

2626

2727
SAFE_DIVIDE = 'IF(({denom}) = 0, NULL, ({numer}) / ({denom}))'
28+
# If to use CREATE TEMP TABLE. Setting it to False disables CREATE TEMP TABLE
29+
# even when it's needed.
30+
ALLOW_TEMP_TABLE = True
31+
# If the engine supports CREATE TEMP TABLE
32+
TEMP_TABLE_SUPPORTED = None
2833

2934

3035
def is_compatible(sql0, sql1):
@@ -67,6 +72,80 @@ def add_suffix(alias):
6772
return alias + '_1'
6873

6974

75+
def rand_run_only_once_in_with_clause(execute):
76+
"""Check if the RAND() is only evaluated once in the WITH clause."""
77+
d = execute(
78+
'''WITH T AS (SELECT RAND() AS r)
79+
SELECT t1.r - t2.r AS d
80+
FROM T t1 CROSS JOIN T t2'''
81+
)
82+
return bool(d.iloc[0, 0] == 0)
83+
84+
85+
def dep_on_rand_table(query, rand_tables):
86+
"""Returns if a SQL query depends on any stochastic table in rand_tables."""
87+
for rand_table in rand_tables:
88+
if re.search(r'\b%s\b' % rand_table, str(query)):
89+
return True
90+
return False
91+
92+
93+
def get_temp_tables(with_data: 'Datasources'):
94+
"""Gets all the subquery tables that need to be materialized.
95+
96+
When generating the query, we assume that volatile functions like RAND() in
97+
the WITH clause behave as if they are evaluated only once. Unfortunately, not
98+
all engines behave like that. In those cases, we need to CREATE TEMP TABLE to
99+
materialize the subqueries that have volatile functions, so that the same
100+
result is used in all places. An example is
101+
WITH T AS (SELECT RAND() AS r)
102+
SELECT t1.r - t2.r AS d
103+
FROM T t1 CROSS JOIN T t2.
104+
If it doesn't always evaluates to 0, we need to create a temp table for T.
105+
A subquery needs to be materialized if
106+
1. it depends on any stochastic table
107+
(e.g. RAND()) and
108+
2. the random column is referenced in the same query multiple times.
109+
#2 is hard to check so we check if the stochastic table is referenced in the
110+
same query multiple times instead.
111+
An exception is the BootstrapRandomChoices table, which refers to a stochastic
112+
table twice but only one refers to the stochasic column, so we don't need to
113+
materialize it.
114+
This function finds all the subquery tables in the WITH clause that need to be
115+
materialized by
116+
1. finding all the stochastic tables,
117+
2. finding all the tables that depend, even indirectly, on a stochastic table,
118+
3. finding all the tables in #2 that are referenced in the same query multiple
119+
times.
120+
121+
Args:
122+
with_data: The with clause.
123+
124+
Returns:
125+
A set of table names that need to be materialized.
126+
"""
127+
tmp_tables = set()
128+
for rand_table in with_data:
129+
query = with_data[rand_table]
130+
if 'RAND' not in str(query):
131+
continue
132+
dep_on_rand = set([rand_table])
133+
for alias in with_data:
134+
if dep_on_rand_table(with_data[alias].from_data, dep_on_rand):
135+
dep_on_rand.add(alias)
136+
for t in dep_on_rand:
137+
from_data = with_data[t].from_data
138+
if isinstance(from_data, Join) and not t.startswith(
139+
'BootstrapRandomChoices'
140+
):
141+
if dep_on_rand_table(from_data.ds1, dep_on_rand) and dep_on_rand_table(
142+
from_data.ds2, dep_on_rand
143+
):
144+
tmp_tables.add(rand_table)
145+
break
146+
return tmp_tables
147+
148+
70149
def get_alias(c):
71150
return getattr(c, 'alias_raw', c)
72151

@@ -571,6 +650,7 @@ class Datasources(SqlComponents):
571650
def __init__(self, datasources=None):
572651
super(Datasources, self).__init__()
573652
self.children = collections.OrderedDict()
653+
self.temp_tables = set()
574654
self.add(datasources)
575655

576656
@property
@@ -676,6 +756,23 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]):
676756
children.alias = add_suffix(alias)
677757
return self.add(children)
678758

759+
def add_temp_table(self, table: Union[str, 'Sql', Join, Datasource]):
760+
"""Marks alias and all its data dependencies as temp tables."""
761+
if isinstance(table, str):
762+
self.temp_tables.add(table)
763+
if table in self.children:
764+
self.add_temp_table(self.children[table])
765+
return
766+
if isinstance(table, Join):
767+
self.add_temp_table(table.ds1)
768+
self.add_temp_table(table.ds2)
769+
return
770+
if isinstance(table, Datasource):
771+
return self.add_temp_table(table.table)
772+
if isinstance(table, Sql):
773+
return self.add_temp_table(table.from_data)
774+
return self
775+
679776
def extend(self, other: 'Datasources'):
680777
"""Merge other to self. Adjust the query if a new alias is needed."""
681778
datasources = list(other.datasources)
@@ -691,7 +788,18 @@ def extend(self, other: 'Datasources'):
691788
return self
692789

693790
def __str__(self):
694-
return ',\n'.join((d.get_expression('WITH') for d in self.datasources if d))
791+
temp_tables = []
792+
with_tables = []
793+
for d in self.datasources:
794+
expression = d.get_expression('WITH')
795+
if d.alias in self.temp_tables:
796+
temp_tables.append(f'CREATE OR REPLACE TEMP TABLE {expression};')
797+
else:
798+
with_tables.append(expression)
799+
res = '\n'.join(temp_tables)
800+
if with_tables:
801+
res += '\nWITH\n' + ',\n'.join(with_tables)
802+
return res.strip()
695803

696804

697805
class Sql(SqlComponent):
@@ -766,7 +874,7 @@ def merge(self, other: 'Sql'):
766874
return True
767875

768876
def __str__(self):
769-
with_clause = 'WITH\n%s' % self.with_data if self.with_data else None
877+
with_clause = str(self.with_data) if self.with_data else None
770878
all_columns = self.all_columns or '*'
771879
select_clause = f'SELECT\n{all_columns}'
772880
from_clause = ('FROM %s'

0 commit comments

Comments
 (0)