Skip to content

Commit 8451ab4

Browse files
tcyameterstick-copybara
authored andcommitted
Fix naming issue for MetricList. name_tmpl wasn't respected in compute_on_sql when MetricList only has one child because its result is squeezed to a number in an intermediate step so the column name is lost.
PiperOrigin-RevId: 771168986
1 parent 55e0193 commit 8451ab4

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

metrics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,7 @@ def __init__(self,
11151115
tmpl = name_tmpl or 'MetricList({})'
11161116
if len(children) == 1:
11171117
name = children[0].name
1118+
name = name_tmpl.format(name) if name_tmpl else name
11181119
else:
11191120
name = tmpl.format(', '.join(m.name for m in children))
11201121
super(MetricList, self).__init__(name, children, where, name_tmpl)
@@ -1154,8 +1155,11 @@ def compute_slices(self, df, split_by=None):
11541155
def compute_on_children(self, children, split_by):
11551156
if isinstance(children, list):
11561157
children = self.to_dataframe(children)
1157-
if self.name_tmpl:
1158-
children.columns = [self.name_tmpl.format(c) for c in children.columns]
1158+
if isinstance(children, pd.DataFrame):
1159+
if self.name_tmpl:
1160+
children.columns = [self.name_tmpl.format(c) for c in children.columns]
1161+
elif not isinstance(children, pd.Series):
1162+
children = pd.DataFrame({self.name: [children]})
11591163
if self.columns:
11601164
if len(children.columns) != len(self.columns):
11611165
raise ValueError(
@@ -1262,8 +1266,7 @@ def compute_children_sql(self, table, split_by, execute, mode=None):
12621266
return_dataframe=self.children_return_dataframe,
12631267
)
12641268
)
1265-
# When there is only one child, returns the result of the child.
1266-
return children[0] if len(self.children) == 1 else children
1269+
return children
12671270

12681271
def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
12691272
local_filter, with_data):

metrics_test.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Tests for Metrics."""
1514

1615
from __future__ import absolute_import
1716
from __future__ import division
@@ -645,7 +644,7 @@ def test_duplicate_column_names(self):
645644
testing.assert_frame_equal(output, expected)
646645

647646

648-
class TestMetricList(absltest.TestCase):
647+
class TestMetricList(parameterized.TestCase):
649648

650649
def test_return_list(self):
651650
df = pd.DataFrame({'X': [0, 1, 2, 3]})
@@ -675,17 +674,26 @@ def test_return_df(self):
675674
}, columns=['sum(X)', 'mean(X)'])
676675
testing.assert_frame_equal(output, expected)
677676

678-
def test_with_name_tmpl(self):
677+
@parameterized.parameters(
678+
([metrics.Sum('X')],), ([metrics.Sum('X'), metrics.Mean('X')],)
679+
)
680+
def test_name_tmpl(self, children):
679681
df = pd.DataFrame({'X': [0, 1, 2, 3]})
680-
ms = [metrics.Sum('X'), metrics.Mean('X')]
681-
m = metrics.MetricList(ms, name_tmpl='a{}b')
682+
m = metrics.MetricList(children, name_tmpl='a{}b')
682683
output = m.compute_on(df)
683-
expected = pd.DataFrame(
684-
data={
685-
'asum(X)b': [6],
686-
'amean(X)b': [1.5]
687-
},
688-
columns=['asum(X)b', 'amean(X)b'])
684+
expected = metrics.MetricList(children).compute_on(df)
685+
expected.columns = [f'a{c}b' for c in expected.columns]
686+
testing.assert_frame_equal(output, expected)
687+
688+
@parameterized.parameters(
689+
([metrics.Sum('X')],), ([metrics.Sum('X'), metrics.Mean('X')],)
690+
)
691+
def test_nested_name_tmpl(self, children):
692+
df = pd.DataFrame({'X': [0, 1, 2, 3]})
693+
m = metrics.MetricList(children, name_tmpl='a {}')
694+
m = metrics.MetricList(m, name_tmpl='b {}')
695+
output = m.compute_on(df)
696+
expected = metrics.MetricList(children, name_tmpl='b a {}').compute_on(df)
689697
testing.assert_frame_equal(output, expected)
690698

691699
def test_operations(self):

0 commit comments

Comments
 (0)