Skip to content

Commit b0b4614

Browse files
tcyameterstick-copybara
authored andcommitted
Allow CUPED/PrePost to iterate on covariates.
PiperOrigin-RevId: 772163965
1 parent 2570fb0 commit b0b4614

File tree

5 files changed

+287
-22
lines changed

5 files changed

+287
-22
lines changed

meterstick_demo.ipynb

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13820,6 +13820,17 @@
1382013820
"Jackknife('cookie', MetricList((pct, prepost))).compute_on(df_prepost)"
1382113821
]
1382213822
},
13823+
{
13824+
"cell_type": "markdown",
13825+
"metadata": {
13826+
"id": "cD4H8FPkXoDs"
13827+
},
13828+
"source": [
13829+
"Note that\n",
13830+
"- When you pass multiple base `Metric`s to `CUPED` or `PrePostChange`, they adjust them one by one. Namely, `CUPED(column, baseline, [post1, post2], pre)` is equivalent to `MetricList([CUPED(column, baseline, post1, pre), CUPED(column, baseline, post2, pre)])`.\n",
13831+
"- When you pass multiple covariates, by default they are all used to do the adjustment. But if you set `multiple_covariates` to `False`, we'll zip the base `Metric`s and the covariates and create a list of single-covariate `CUPED`. Namely, `CUPED(child=[x1, x2], covariates=[y1, y2], multiple_covariates=False)` is equivalent to `MetricList([CUPED(child=x1, covariates=y1), CUPED(child=x2, covariates=y2)])`."
13832+
]
13833+
},
1382313834
{
1382413835
"cell_type": "markdown",
1382513836
"metadata": {

metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ def where(self, where):
271271
self.where_ = tuple(where)
272272

273273
def add_where(self, where):
274+
if where is None:
275+
return self
274276
where = [where] if isinstance(where, str) else list(where) or []
275277
if not self.where_:
276278
self.where = where
@@ -883,7 +885,7 @@ def get_equivalent(self, *auxiliary_cols):
883885
res = self.get_equivalent_without_filter(*auxiliary_cols) # pylint: disable=assignment-from-none
884886
if res:
885887
res.name = self.name
886-
res.where = self.where_
888+
res.add_where(self.where_)
887889
return res
888890

889891
def get_equivalent_without_filter(self, *auxiliary_cols):

0 commit comments

Comments
 (0)