Skip to content

Commit 2570fb0

Browse files
tcyameterstick-copybara
authored andcommitted
Allow the preperiod metric and postperiod metric in CUPED/PrePostChange to have same names.
PiperOrigin-RevId: 771202290
1 parent f1665cd commit 2570fb0

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

operations.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,10 @@ class Adjust(metrics.Metric):
799799

800800
def compute_slices(self, df, split_by: Optional[List[Text]] = None):
801801
child = df.iloc[:, :len_child]
802+
prefix = utils.get_unique_prefix(child)
803+
df.columns = list(child.columns) + [
804+
prefix + c for c in df.columns[len_child:]
805+
]
802806
covariate = df.iloc[:, len_child:]
803807
if len(covariate.columns) > 1:
804808
return super(Adjust, self).compute_slices(df, split_by)
@@ -826,7 +830,7 @@ def compute_children_sql(self, table, split_by, execute, mode=None):
826830
child = super(PrePostChange,
827831
self).compute_children_sql(table, split_by, execute, mode)
828832
covariates = child.iloc[:, -self.k_covariates:]
829-
child = child.iloc[:, :self.k_covariates]
833+
child = child.iloc[:, :-self.k_covariates]
830834
return self.adjust_value(child, covariates, split_by)
831835

832836
def get_change_raw_sql(
@@ -1045,6 +1049,10 @@ class Adjust(metrics.Metric):
10451049

10461050
def compute_slices(self, df, split_by: Optional[List[Text]] = None):
10471051
child = df.iloc[:, :len_child]
1052+
prefix = utils.get_unique_prefix(child)
1053+
df.columns = list(child.columns) + [
1054+
prefix + c for c in df.columns[len_child:]
1055+
]
10481056
covariate = df.iloc[:, len_child:]
10491057
if len(covariate.columns) > 1:
10501058
return super(Adjust, self).compute_slices(df, split_by)
@@ -1073,7 +1081,7 @@ def compute_children_sql(self, table, split_by, execute, mode=None):
10731081
child = super(CUPED, self).compute_children_sql(table, split_by, execute,
10741082
mode)
10751083
covariates = child.iloc[:, -self.k_covariates:]
1076-
child = child.iloc[:, :self.k_covariates]
1084+
child = child.iloc[:, :-self.k_covariates]
10771085
return self.adjust_value(child, covariates, split_by)
10781086

10791087
def get_change_raw_sql(

operations_test.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def test_nxx(self):
256256
testing.assert_frame_equal(output, expected)
257257

258258

259-
class PrePostChangeTests(absltest.TestCase):
259+
class PrePostChangeTests(parameterized.TestCase):
260260
n = 40
261261
df = pd.DataFrame({
262262
'x': np.random.choice(range(20), n),
@@ -428,7 +428,7 @@ def test_complex(self):
428428
testing.assert_frame_equal(output, expected)
429429

430430

431-
class CUPEDTests(absltest.TestCase):
431+
class CUPEDTests(parameterized.TestCase):
432432
n = 40
433433
df = pd.DataFrame({
434434
'x': np.random.choice(range(20), n),
@@ -2362,6 +2362,33 @@ def test_different_metrics_have_different_fingerprints(self):
23622362
)
23632363
self.assertLen(fingerprints, len(distinct_ops))
23642364

2365+
@parameterized.parameters([operations.CUPED, operations.PrePostChange])
2366+
def test_cuped_prepost_with_duplicate_names_one_covariate(self, op):
2367+
s = metrics.Sum('x')
2368+
cov = metrics.Sum('x', where='x>0.1', name='foo')
2369+
cov_dup = metrics.Sum('x', where='x>0.1')
2370+
jk = operations.Jackknife('cookie', confidence=0.9)
2371+
op_dup = op('grp', 1, s, cov_dup, 'grp4')
2372+
op = op('grp', 1, s, cov, 'grp4')
2373+
output = jk(op_dup).compute_on(self.df).display(return_formatted_df=True)
2374+
expected = jk(op).compute_on(self.df).display(return_formatted_df=True)
2375+
testing.assert_frame_equal(output, expected)
2376+
2377+
@parameterized.parameters([operations.CUPED, operations.PrePostChange])
2378+
def test_cuped_prepost_with_duplicate_names_multiple_covariates(self, op):
2379+
s = metrics.Sum('x', name='foo')
2380+
cov_dup = [
2381+
metrics.Sum('x', where='x>0.1', name='foo'),
2382+
metrics.Sum('y', name='foo'),
2383+
]
2384+
cov = [metrics.Sum('x', where='x>0.1'), metrics.Sum('y')]
2385+
jk = operations.Jackknife('cookie', confidence=0.9)
2386+
op_dup = op('grp', 1, s, cov_dup, 'grp4')
2387+
op = op('grp', 1, s, cov, 'grp4')
2388+
output = jk(op_dup).compute_on(self.df).display(return_formatted_df=True)
2389+
expected = jk(op).compute_on(self.df).display(return_formatted_df=True)
2390+
testing.assert_frame_equal(output, expected)
2391+
23652392

23662393
if __name__ == '__main__':
23672394
absltest.main()

0 commit comments

Comments
 (0)