Skip to content

Commit 9c36151

Browse files
tcya authored and meterstick-copybara committed
Generalize the sql dialect to be configurable. This allows Meterstick to be used with other SQL dialects than GoogleSQL, which is requested in #245. Tested in https://colab.research.google.com/drive/1y3UigzEby1anMM3-vXocBx7V8LVblIAp?usp=sharing
PiperOrigin-RevId: 784820403
1 parent b843fed commit 9c36151

File tree

10 files changed

+736
-173
lines changed

10 files changed

+736
-173
lines changed

README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ routine data analysis tasks. Please see [meterstick_demo.ipynb](https://colab.re
77

88
This is not an officially supported Google product.
99

10-
1110
## tl;dr
1211

1312
Modify the demo colab [notebook](https://colab.research.google.com/github/google/meterstick/blob/master/meterstick_demo.ipynb) and adapt it to your needs.
@@ -41,7 +40,6 @@ This calculates the percent change in conversion rate and bounce rate,
4140
relative to the control arm, for each country and device, together with
4241
95% confidence intervals based on jackknife standard errors.
4342

44-
4543
## Building Blocks of an Analysis Object
4644

4745
### Metrics
@@ -249,7 +247,6 @@ metrics for non-spam clicks you can add a `where` clause to the Metric or
249247
MetricList. This clause is a boolean expression which can be passed to pandas'
250248
[query() method](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html).
251249

252-
253250
```python
254251
sum_non_spam_clicks = Sum("Clicks", where="~IsSpam")
255252
MetricList([Sum("Clicks"), Sum("Conversions")], where="~IsSpam")
@@ -308,9 +305,7 @@ It can help you to sanity check complex Metrics.
308305

309306
You can get the SQL query for all built-in Metrics and Operations by calling
310307
`to_sql(sql_data_source, split_by)` on the Metric. `sql_data_source` could be a
311-
table or a subquery. The dialect it uses is the
312-
[standard SQL](https://cloud.google.com/bigquery/docs/reference/standard-sql)
313-
in Google Cloud's BigQuery. For example,
308+
table or a subquery. For example,
314309

315310
```python
316311
MetricList((Sum('X', where='Y > 0'), Sum('X'))).to_sql('table', 'grp')
@@ -338,6 +333,19 @@ function that can execute SQL query. The `mode` can be `None` or
338333
`'mixed'`. The former is recommended and computes things in SQL whenever
339334
possible while the latter only computes the leaf Metrics in SQL.
340335

336+
The default dialect it uses is GoogleSQL. You can use `set_dialect()` to choose
337+
other dialects. Currently we support
338+
339+
* PostgreSQL
340+
* MySQL and MariaDB
341+
* SQLite
342+
* Oracle
343+
* Microsoft SQL Server
344+
* Trino SQL
345+
346+
For other dialects, you can manually overwrite the default string templates at
347+
the top of `sql.py` file.
348+
341349
## Apache Beam
342350

343351
There is also a

diversity.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,22 @@
2626
class DiversityBase(operations.Distribution):
2727
"""Base class that captures shared logic of diversity Operations."""
2828

29-
def __init__(self, over, child, name_tmpl, additional_fingerprint_attrs=None):
29+
def __init__(
30+
self,
31+
over,
32+
child,
33+
name_tmpl,
34+
where,
35+
additional_fingerprint_attrs=None,
36+
**kwargs,
37+
):
3038
super(DiversityBase, self).__init__(
3139
over,
3240
child,
3341
name_tmpl,
42+
where=where,
3443
additional_fingerprint_attrs=additional_fingerprint_attrs,
44+
**kwargs
3545
)
3646
self.extra_index = []
3747

@@ -59,8 +69,8 @@ def to_dataframe(self, res):
5969
class HHI(DiversityBase):
6070
"""Herfindahl–Hirschman index of metric distribution."""
6171

62-
def __init__(self, over, child=None):
63-
super(HHI, self).__init__(over, child, 'HHI of {}')
72+
def __init__(self, over, child=None, where=None, **kwargs):
73+
super(HHI, self).__init__(over, child, 'HHI of {}', where, **kwargs)
6474

6575
def compute_on_children(self, child, split_by):
6676
dist = super(HHI, self).compute_on_children(child, split_by)
@@ -100,8 +110,8 @@ def get_sql_and_with_clause(
100110
class Entropy(DiversityBase):
101111
"""Entropy of metric distribution."""
102112

103-
def __init__(self, over, child=None):
104-
super(Entropy, self).__init__(over, child, 'Entropy of {}')
113+
def __init__(self, over, child=None, where=None, **kwargs):
114+
super(Entropy, self).__init__(over, child, 'Entropy of {}', where, **kwargs)
105115

106116
def compute_on_children(self, child, split_by):
107117
dist = super(Entropy, self).compute_on_children(child, split_by)
@@ -141,14 +151,24 @@ def get_sql_and_with_clause(
141151
class TopK(DiversityBase):
142152
"""The total share of the largest k contributors."""
143153

144-
def __init__(self, over, k, child=None, additional_fingerprint_attrs=None):
154+
def __init__(
155+
self,
156+
over,
157+
k,
158+
child=None,
159+
where=None,
160+
additional_fingerprint_attrs=None,
161+
**kwargs,
162+
):
145163
if not isinstance(k, int):
146164
raise ValueError('k must be an integer!')
147165
super(TopK, self).__init__(
148166
over,
149167
child,
150168
"Top-%s's share of {}" % k,
169+
where,
151170
['k'] + (additional_fingerprint_attrs or []),
171+
**kwargs,
152172
)
153173
self.k = k
154174

@@ -173,9 +193,14 @@ def get_sql_and_with_clause(
173193
1. Get the query for the Distribution of the child Metric.
174194
2. Keep all indexing/groupby columns unchanged.
175195
3. For all value columns, collect the top-k values into an array by
176-
ARRAY_AGG(val_col ORDER BY val_col DESC LIMIT k) AS val_arr.
177-
4. For all value columns, do 'SELECT SUM(x) FROM UNNEST(val_arr) AS x' to
178-
get the sum of the top-k values.
196+
ARRAY_AGG(val_col IGNORE NULLS ORDER BY val_col DESC LIMIT k) AS val_arr.
197+
Note that the ordering between number and NULLs varies by dialect so we
198+
use IGNORE NULLS.
199+
4. For all value columns, do
200+
'SELECT SUM(x) FROM UNNEST(val_arr) AS x WITH OFFSET AS i WHERE i < k'
201+
to get the sum of the top-k values. Note that the
202+
'WITH OFFSET AS i WHERE i < k' is redundant here but many external
203+
dialects don't support 'LIMIT k' in #3 so we need to do it in #4.
179204
180205
Args:
181206
table: The table we want to query from.
@@ -206,14 +231,14 @@ def get_sql_and_with_clause(
206231
continue
207232

208233
top_k_array_col = sql.Column(
209-
(c.alias, c.alias),
210-
'ARRAY_AGG({} ORDER BY {} DESC LIMIT %s)' % self.k,
234+
c.alias,
235+
sql.ARRAY_AGG_FN(c.alias, ascending=False, dropna=True, limit=self.k),
211236
)
212237
top_k_array_col.set_alias(c.alias_raw)
213238
top_k_array_columns.add(top_k_array_col)
214239
top_k_sum_col = sql.Column(
215-
top_k_array_col.alias,
216-
'(SELECT SUM(x) FROM UNNEST({}) AS x)',
240+
'(SELECT SUM(x) FROM'
241+
f' {sql.UNNEST_ARRAY_FN(top_k_array_col.alias, "x", "i", self.k)})',
217242
)
218243
top_k_sum_col.set_alias(self.name_tmpl.format(c.alias_raw))
219244
top_k_sum_columns.add(top_k_sum_col)
@@ -229,7 +254,13 @@ class Nxx(DiversityBase):
229254
"""The minimum number of contributors to achieve certain share."""
230255

231256
def __init__(
232-
self, over, share, child=None, additional_fingerprint_attrs=None
257+
self,
258+
over,
259+
share,
260+
child=None,
261+
where=None,
262+
additional_fingerprint_attrs=None,
263+
**kwargs,
233264
):
234265
if not 0 < share <= 1:
235266
raise ValueError('Share must be in (0, 1]!')
@@ -238,7 +269,9 @@ def __init__(
238269
child,
239270
'N(%s) of {}'
240271
% (int(100 * share) if (100 * share).is_integer() else 100 * share),
272+
where,
241273
['share'] + (additional_fingerprint_attrs or []),
274+
**kwargs,
242275
)
243276
self.share = share
244277

@@ -308,8 +341,7 @@ def get_sql_and_with_clause(
308341
cumsum_cols.add(cumsum_col)
309342

310343
nxx_col = sql.Column(
311-
cumsum_col.alias,
312-
'COUNTIF({} < %s) + 1' % self.share,
344+
cumsum_col.alias, sql.COUNTIF_FN('{} < %s' % self.share) + ' + 1'
313345
)
314346
nxx_col.set_alias(self.name_tmpl.format(c.alias_raw))
315347
nxx_cols.add(nxx_col)

meterstick_custom_metrics.ipynb

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,8 @@
7878
},
7979
{
8080
"cell_type": "code",
81-
"execution_count": 3,
81+
"execution_count": null,
8282
"metadata": {
83-
"executionInfo": {
84-
"elapsed": 52,
85-
"status": "ok",
86-
"timestamp": 1727469298213,
87-
"user": {
88-
"displayName": "",
89-
"userId": ""
90-
},
91-
"user_tz": 420
92-
},
9383
"id": "0UI9rAtZnBUG"
9484
},
9585
"outputs": [],
@@ -2690,17 +2680,17 @@
26902680
"execution_count": null,
26912681
"metadata": {
26922682
"executionInfo": {
2693-
"elapsed": 60,
2683+
"elapsed": 59,
26942684
"status": "ok",
2695-
"timestamp": 1684897961304,
2685+
"timestamp": 1752749803027,
26962686
"user": {
2697-
"displayName": "",
2698-
"userId": ""
2687+
"displayName": "Xunmo Yang",
2688+
"userId": "12474546967758012552"
26992689
},
27002690
"user_tz": 420
27012691
},
2702-
"id": "F9w6gugy3vYf",
2703-
"outputId": "a6853397-c2c2-49c3-a8db-294c15d18b39"
2692+
"id": "5dKXoRHZ_Xi4",
2693+
"outputId": "50aa82e5-e70c-48e3-ee48-cfe6ec7ab880"
27042694
},
27052695
"outputs": [
27062696
{
@@ -2715,11 +2705,10 @@
27152705
"SELECT\n",
27162706
" country,\n",
27172707
" SAFE_DIVIDE(sum_clicks, SUM(sum_clicks) OVER ()) AS Distribution_of_sum_clicks\n",
2718-
"FROM DistributionRaw\n",
2719-
"GROUP BY country, Distribution_of_sum_clicks"
2708+
"FROM DistributionRaw"
27202709
]
27212710
},
2722-
"execution_count": 194,
2711+
"execution_count": 18,
27232712
"metadata": {},
27242713
"output_type": "execute_result"
27252714
}
@@ -2783,12 +2772,12 @@
27832772
" table, indexes, global_filter, indexes, local_filter, with_data\n",
27842773
" )\n",
27852774
" child_table = sql.Datasource(child_sql, 'DistributionRaw')\n",
2786-
" # Always use the alias returned by with_data.add(), because if the with_data\n",
2787-
" # already holds a different table that also has 'DistributionRaw' as its\n",
2788-
" # alias, we'll use a different alias for the child_table, which is returned\n",
2789-
" # by with_data.add().\n",
2790-
" child_table_alias = with_data.add(child_table)\n",
2791-
" groupby = sql.Columns(indexes.aliases, distinct=True)\n",
2775+
" # Always use the alias returned by with_data.merge(), because if the\n",
2776+
" # with_data already holds a different table that also has 'DistributionRaw'\n",
2777+
" # as its alias, we'll use a different alias for the child_table, which is\n",
2778+
" # returned by with_data.merge().\n",
2779+
" child_table_alias = with_data.merge(child_table)\n",
2780+
" groupby = sql.Columns(indexes.aliases)\n",
27922781
" columns = sql.Columns()\n",
27932782
" for c in child_sql.columns:\n",
27942783
" if c.alias in groupby:\n",
@@ -3157,7 +3146,7 @@
31573146
},
31583147
{
31593148
"cell_type": "code",
3160-
"execution_count": 4,
3149+
"execution_count": null,
31613150
"metadata": {
31623151
"executionInfo": {
31633152
"elapsed": 53,

meterstick_demo.ipynb

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18595,7 +18595,16 @@
1859518595
"\n",
1859618596
"where `execute` is a function that can execute SQL queries. The return is very similar to compute_on().\n",
1859718597
"\n",
18598-
"The dialect it uses is the [standard SQL](https://cloud.google.com/bigquery/docs/reference/standard-sql) in Google Cloud's BigQuery.\n",
18598+
"The default dialect it uses is GoogleSQL. You can use `set_dialect()` to choose other dialects. Currently we support\n",
18599+
"\n",
18600+
"* PostgreSQL\n",
18601+
"* MySQL and MariaDB\n",
18602+
"* SQLite\n",
18603+
"* Oracle\n",
18604+
"* Microsoft SQL Server\n",
18605+
"* Trino SQL\n",
18606+
"\n",
18607+
"For other dialects, you can manually overwrite the default string templates at the top of `sql.py` file.\n",
1859918608
"\n",
1860018609
"The choice of `create_tmp_table_for_volatile_fn` depends on your SQL engine. If query\n",
1860118610
"```\n",
@@ -18615,17 +18624,17 @@
1861518624
{
1861618625
"metadata": {
1861718626
"executionInfo": {
18618-
"elapsed": 59,
18627+
"elapsed": 54,
1861918628
"status": "ok",
18620-
"timestamp": 1750186450282,
18629+
"timestamp": 1752474546796,
1862118630
"user": {
1862218631
"displayName": "Xunmo Yang",
1862318632
"userId": "12474546967758012552"
1862418633
},
1862518634
"user_tz": 420
1862618635
},
18627-
"id": "eoHY1kVlPbSL",
18628-
"outputId": "e4f42347-809c-492d-ec9a-a72103fd86ef"
18636+
"id": "U4Gef-jGn2SY",
18637+
"outputId": "835c7c07-c13f-44a7-886d-9382e204081a"
1862918638
},
1863018639
"cell_type": "code",
1863118640
"source": [
@@ -18637,13 +18646,13 @@
1863718646
"text/plain": [
1863818647
"SELECT\n",
1863918648
" grp,\n",
18640-
" SUM(IF(Y \u003e 0, X, 0)) AS sum_X,\n",
18649+
" SUM(CASE WHEN Y \u003e 0 THEN X ELSE 0 END) AS sum_X,\n",
1864118650
" SUM(X) AS sum_X_1\n",
1864218651
"FROM T\n",
1864318652
"GROUP BY grp"
1864418653
]
1864518654
},
18646-
"execution_count": 54,
18655+
"execution_count": 492,
1864718656
"metadata": {},
1864818657
"output_type": "execute_result"
1864918658
}
@@ -21198,8 +21207,7 @@
2119821207
"file_id": "1u9XmuUlA0TtGmERFV1cSY-4UWpXYIJWL",
2119921208
"timestamp": 1588129678918
2120021209
}
21201-
],
21202-
"toc_visible": true
21210+
]
2120321211
},
2120421212
"kernelspec": {
2120521213
"display_name": "Python 3",

0 commit comments

Comments
 (0)