Skip to content

Commit 9e1e41b

Browse files
glados-verma and copybara-github
authored and committed
Add from_dataframe method to Measurement to create multidim measurement from a DataFrame.
This also adds some symmetry with the existing to_dataframe method. PiperOrigin-RevId: 721549130
1 parent b8cfa7e commit 9e1e41b

File tree

2 files changed

+192
-33
lines changed

2 files changed

+192
-33
lines changed

openhtf/core/measurements.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,34 @@ def to_dataframe(self, columns: Any = None) -> Any:
480480

481481
return dataframe
482482

483+
def from_dataframe(self, dataframe: Any, metric_column: str) -> None:
  """Populate this multi-dim measurement from a pandas DataFrame.

  Each row becomes one measured point: the row's values in the dimension
  columns form the coordinates, and its value in `metric_column` is stored at
  those coordinates. Rows are written in order, so a later row with the same
  coordinates overwrites an earlier one. Columns that are neither dimensions
  nor the metric are ignored.

  Args:
    dataframe: A pandas DataFrame. Dimensions for this multi-dim measurement
      need to match columns in the DataFrame (can be multi-index).
    metric_column: The column name of the metric to be measured.

  Raises:
    TypeError: If this measurement is not a dimensioned measurement.
    ValueError: If dataframe is missing dimensions, or has no column named
      `metric_column`.
  """
  if not isinstance(self._measured_value, DimensionedMeasuredValue):
    raise TypeError(
        'Only a dimensioned measurement can be set from a DataFrame'
    )
  dimension_labels = [d.name for d in self.dimensions]
  # Move any existing index levels back into columns first, so flat frames and
  # (multi-)indexed frames are handled by the same set_index call below.
  flat_df = dataframe.reset_index()
  try:
    dimensioned_df = flat_df.set_index(dimension_labels)
  except KeyError as e:
    raise ValueError('DataFrame is missing dimensions') from e
  if metric_column not in dimensioned_df.columns:
    raise ValueError(
        f'DataFrame does not have a column named {metric_column}'
    )
  # Iterate only the metric column. DataFrame.iterrows() would build a Series
  # per row spanning every remaining column, which coerces the metric to the
  # row's common dtype (e.g. int -> float when another column holds floats).
  for row_dimensions, metric_value in dimensioned_df[metric_column].items():
    self.measured_value[row_dimensions] = metric_value
510+
483511

484512
@attr.s(slots=True)
485513
class MeasuredValue(object):

test/core/measurements_test.py

Lines changed: 164 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from openhtf.core import measurements
2626
from examples import all_the_things
2727
from openhtf.util import test as htf_test
28+
import pandas
2829

2930
# Fields that are considered 'volatile' for record comparison.
3031
_VOLATILE_FIELDS = {
@@ -105,18 +106,24 @@ def test_precision(self):
105106
m.measured_value.set(1.2346)
106107
self.assertAlmostEqual(m.measured_value.value, 1.235)
107108

108-
m = htf.Measurement('meas_with_precision_and_dims').with_precision(
109-
3).with_dimensions('x')
109+
m = (
110+
htf.Measurement('meas_with_precision_and_dims')
111+
.with_precision(3)
112+
.with_dimensions('x')
113+
)
110114
m.measured_value[42] = 1.2346
111115
self.assertAlmostEqual(m.measured_value[42], 1.235)
112116

113117
def test_cache_same_object(self):
114118
m = htf.Measurement('measurement')
115119
basetypes0 = m.as_base_types()
116-
self.assertEqual({
117-
'name': 'measurement',
118-
'outcome': 'UNSET',
119-
}, basetypes0)
120+
self.assertEqual(
121+
{
122+
'name': 'measurement',
123+
'outcome': 'UNSET',
124+
},
125+
basetypes0,
126+
)
120127
basetypes1 = m.as_base_types()
121128
self.assertIs(basetypes0, basetypes1)
122129
m.measured_value.set(1)
@@ -127,7 +134,9 @@ def test_cache_same_object(self):
127134
'name': 'measurement',
128135
'outcome': 'PASS',
129136
'measured_value': 1,
130-
}, basetypes2)
137+
},
138+
basetypes2,
139+
)
131140
self.assertIs(basetypes0, basetypes2)
132141

133142
@htf_test.patch_plugs(user_mock='openhtf.plugs.user_input.UserInput')
@@ -142,52 +151,66 @@ def test_chaining_in_measurement_declarations(self, user_mock):
142151
@htf_test.yields_phases
143152
def test_measurements_with_dimensions(self):
144153
record = yield all_the_things.dimensions
145-
self.assertMeasured(record, 'dimensions', [
146-
(0, 1),
147-
(1, 2),
148-
(2, 4),
149-
(3, 8),
150-
(4, 16),
151-
])
152-
self.assertMeasured(record, 'lots_of_dims', [
153-
(1, 21, 101, 123),
154-
(2, 22, 102, 126),
155-
(3, 23, 103, 129),
156-
(4, 24, 104, 132),
157-
])
154+
self.assertMeasured(
155+
record,
156+
'dimensions',
157+
[
158+
(0, 1),
159+
(1, 2),
160+
(2, 4),
161+
(3, 8),
162+
(4, 16),
163+
],
164+
)
165+
self.assertMeasured(
166+
record,
167+
'lots_of_dims',
168+
[
169+
(1, 21, 101, 123),
170+
(2, 22, 102, 126),
171+
(3, 23, 103, 129),
172+
(4, 24, 104, 132),
173+
],
174+
)
158175

159176
@htf_test.yields_phases
160177
def test_validator_replacement(self):
161178
record = yield all_the_things.measures_with_args.with_args(
162-
minimum=2, maximum=4)
179+
minimum=2, maximum=4
180+
)
163181
self.assertMeasurementFail(record, 'replaced_min_only')
164182
self.assertMeasurementPass(record, 'replaced_max_only')
165183
self.assertMeasurementFail(record, 'replaced_min_max')
166184
record = yield all_the_things.measures_with_args.with_args(
167-
minimum=0, maximum=5)
185+
minimum=0, maximum=5
186+
)
168187
self.assertMeasurementPass(record, 'replaced_min_only')
169188
self.assertMeasurementPass(record, 'replaced_max_only')
170189
self.assertMeasurementPass(record, 'replaced_min_max')
171190
record = yield all_the_things.measures_with_args.with_args(
172-
minimum=-1, maximum=0)
191+
minimum=-1, maximum=0
192+
)
173193
self.assertMeasurementPass(record, 'replaced_min_only')
174194
self.assertMeasurementFail(record, 'replaced_max_only')
175195
self.assertMeasurementFail(record, 'replaced_min_max')
176196

177197
@htf_test.yields_phases
178198
def test_validator_replacement_marginal(self):
179199
record = yield all_the_things.measures_with_marginal_args.with_args(
180-
marginal_minimum=4, marginal_maximum=6)
200+
marginal_minimum=4, marginal_maximum=6
201+
)
181202
self.assertMeasurementMarginal(record, 'replaced_marginal_min_only')
182203
self.assertMeasurementNotMarginal(record, 'replaced_marginal_max_only')
183204
self.assertMeasurementMarginal(record, 'replaced_marginal_min_max')
184205
record = yield all_the_things.measures_with_marginal_args.with_args(
185-
marginal_minimum=1, marginal_maximum=2)
206+
marginal_minimum=1, marginal_maximum=2
207+
)
186208
self.assertMeasurementNotMarginal(record, 'replaced_marginal_min_only')
187209
self.assertMeasurementMarginal(record, 'replaced_marginal_max_only')
188210
self.assertMeasurementMarginal(record, 'replaced_marginal_min_max')
189211
record = yield all_the_things.measures_with_marginal_args.with_args(
190-
marginal_minimum=2, marginal_maximum=4)
212+
marginal_minimum=2, marginal_maximum=4
213+
)
191214
self.assertMeasurementNotMarginal(record, 'replaced_marginal_min_only')
192215
self.assertMeasurementNotMarginal(record, 'replaced_marginal_max_only')
193216
self.assertMeasurementNotMarginal(record, 'replaced_marginal_min_max')
@@ -196,12 +219,15 @@ def test_validator_replacement_marginal(self):
196219
def test_measurement_order(self):
197220
record = yield all_the_things.dimensions
198221
self.assertEqual(
199-
list(record.measurements.keys()), ['dimensions', 'lots_of_dims'])
222+
list(record.measurements.keys()), ['dimensions', 'lots_of_dims']
223+
)
200224
record = yield all_the_things.measures_with_args.with_args(
201-
minimum=2, maximum=4)
225+
minimum=2, maximum=4
226+
)
202227
self.assertEqual(
203228
list(record.measurements.keys()),
204-
['replaced_min_only', 'replaced_max_only', 'replaced_min_max'])
229+
['replaced_min_only', 'replaced_max_only', 'replaced_min_max'],
230+
)
205231

206232
@htf_test.yields_phases
207233
def test_bad_validation(self):
@@ -231,14 +257,19 @@ def test_to_dataframe__no_pandas(self):
231257
with self.assertRaises(RuntimeError):
232258
self.test_to_dataframe(units=True)
233259

234-
def test_to_dataframe(self, units=True):
260+
def _make_multidim_measurement(self, units=''):
235261
measurement = htf.Measurement('test_multidim')
236262
measurement.with_dimensions('ms', 'assembly', htf.Dimension('my_zone'))
263+
if units:
264+
measurement.with_units(units)
265+
return measurement
237266

267+
def test_to_dataframe(self, units=True):
238268
if units:
239-
measurement.with_units('°C')
269+
measurement = self._make_multidim_measurement('°C')
240270
measure_column_name = 'degree Celsius'
241271
else:
272+
measurement = self._make_multidim_measurement()
242273
measure_column_name = 'value'
243274

244275
for t in range(5):
@@ -260,6 +291,104 @@ def test_to_dataframe(self, units=True):
260291
def test_to_dataframe__no_units(self):
261292
self.test_to_dataframe(units=False)
262293

294+
def test_from_dataframe_raises_if_dimensions_missing_in_dataframe(self):
295+
measurement = self._make_multidim_measurement('°C')
296+
with self.assertRaisesRegex(ValueError, 'DataFrame is missing dimensions'):
297+
measurement.from_dataframe(
298+
pandas.DataFrame({
299+
'ms': [1, 2, 3],
300+
'my_zone': ['X', 'Y', 'Z'],
301+
'degree_celsius': [10, 20, 30],
302+
}),
303+
metric_column='degree_celsius',
304+
)
305+
306+
def test_from_dataframe_raises_if_metric_missing_in_dataframe(self):
307+
measurement = self._make_multidim_measurement('°C')
308+
with self.assertRaisesRegex(
309+
ValueError, 'DataFrame does not have a column named degree_celsius'
310+
):
311+
measurement.from_dataframe(
312+
pandas.DataFrame({
313+
'ms': [1, 2, 3],
314+
'assembly': ['A', 'B', 'C'],
315+
'my_zone': ['X', 'Y', 'Z'],
316+
'degrees_fahrenheit': [10, 20, 30],
317+
}),
318+
metric_column='degree_celsius',
319+
)
320+
321+
def test_from_flat_dataframe(self):
322+
measurement = self._make_multidim_measurement('°C')
323+
source_dataframe = pandas.DataFrame({
324+
'ms': [1, 2, 3],
325+
'assembly': ['A', 'B', 'C'],
326+
'my_zone': ['X', 'Y', 'Z'],
327+
'degree_celsius': [10, 20, 30],
328+
})
329+
measurement.from_dataframe(source_dataframe, metric_column='degree_celsius')
330+
measurement.outcome = measurements.Outcome.PASS
331+
self.assertEqual(measurement.measured_value[(1, 'A', 'X')], 10)
332+
self.assertEqual(measurement.measured_value[(2, 'B', 'Y')], 20)
333+
self.assertEqual(measurement.measured_value[(3, 'C', 'Z')], 30)
334+
pandas.testing.assert_frame_equal(
335+
measurement.to_dataframe().rename(
336+
columns={
337+
'ms': 'ms',
338+
'assembly': 'assembly',
339+
'my_zone': 'my_zone',
340+
# The metric column name comes from the unit.
341+
'degree Celsius': 'degree_celsius',
342+
}
343+
),
344+
source_dataframe,
345+
)
346+
347+
def test_from_multiindex_dataframe(self):
348+
measurement = self._make_multidim_measurement('°C')
349+
source_dataframe = pandas.DataFrame({
350+
'ms': [1, 2, 3],
351+
'assembly': ['A', 'B', 'C'],
352+
'my_zone': ['X', 'Y', 'Z'],
353+
'degree_celsius': [10, 20, 30],
354+
})
355+
source_dataframe.set_index(['ms', 'assembly', 'my_zone'], inplace=True)
356+
measurement.from_dataframe(source_dataframe, metric_column='degree_celsius')
357+
measurement.outcome = measurements.Outcome.PASS
358+
self.assertEqual(measurement.measured_value[(1, 'A', 'X')], 10)
359+
self.assertEqual(measurement.measured_value[(2, 'B', 'Y')], 20)
360+
self.assertEqual(measurement.measured_value[(3, 'C', 'Z')], 30)
361+
362+
def test_from_dataframe_with_extra_columns(self):
363+
measurement = self._make_multidim_measurement('°C')
364+
source_dataframe = pandas.DataFrame({
365+
'ms': [1, 2, 3],
366+
'assembly': ['A', 'B', 'C'],
367+
'my_zone': ['X', 'Y', 'Z'],
368+
'degree_celsius': [10, 20, 30],
369+
'degrees_fahrenheit': [11, 21, 31],
370+
})
371+
measurement.from_dataframe(source_dataframe, metric_column='degree_celsius')
372+
measurement.outcome = measurements.Outcome.PASS
373+
self.assertEqual(measurement.measured_value[(1, 'A', 'X')], 10)
374+
self.assertEqual(measurement.measured_value[(2, 'B', 'Y')], 20)
375+
self.assertEqual(measurement.measured_value[(3, 'C', 'Z')], 30)
376+
377+
def test_from_dataframe_with_duplicate_dimensions_overwrites(self):
378+
measurement = self._make_multidim_measurement('°C')
379+
source_dataframe = pandas.DataFrame({
380+
'ms': [1, 2, 3, 1],
381+
'assembly': ['A', 'B', 'C', 'A'],
382+
'my_zone': ['X', 'Y', 'Z', 'X'],
383+
'degree_celsius': [10, 20, 30, 11],
384+
})
385+
measurement.from_dataframe(source_dataframe, metric_column='degree_celsius')
386+
measurement.outcome = measurements.Outcome.PASS
387+
# Overwritten value.
388+
self.assertEqual(measurement.measured_value[(1, 'A', 'X')], 11)
389+
self.assertEqual(measurement.measured_value[(2, 'B', 'Y')], 20)
390+
self.assertEqual(measurement.measured_value[(3, 'C', 'Z')], 30)
391+
263392
def test_bad_validator(self):
264393
measurement = htf.Measurement('bad_measure')
265394
measurement.with_dimensions('a')
@@ -346,8 +475,10 @@ def test_multi_dimension_correct(self):
346475
try:
347476
measurement.measured_value[dimension_vals] = 42
348477
except measurements.InvalidDimensionsError:
349-
self.fail('measurement.DimensionedMeasuredValue.__setitem__ '
350-
'raised error unexpectedly.')
478+
self.fail(
479+
'measurement.DimensionedMeasuredValue.__setitem__ '
480+
'raised error unexpectedly.'
481+
)
351482

352483
def test_multi_dimension_not_enough_error(self):
353484
measurement = htf.Measurement('measure')

0 commit comments

Comments
 (0)