Skip to content

Commit a2c2728

Browse files
authored
Fix statistics calculations (TGSAI#679)
* fix stats calculation and tolerances * switch allclose assertion at attribute comparison * make values list * drop histogram from value assertion
1 parent 954c53b commit a2c2728

File tree

2 files changed

+25
-20
lines changed

2 files changed

+25
-20
lines changed

src/mdio/segy/_workers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,15 @@ def trace_worker( # noqa: PLR0913
134134
ds_to_write = dataset[worker_variables]
135135

136136
if header_key in worker_variables:
137-
# Create temporary array for headers with the correct shape
138-
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
137+
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code
138+
# https://github.com/TGSAI/mdio-python/issues/584
139139
tmp_headers = np.zeros_like(dataset[header_key])
140140
tmp_headers[not_null] = traces.header
141141
# Create a new Variable object to avoid copying the temporary array
142142
# The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
143143
# but Xarray appears to be copying memory instead of doing direct assignment.
144144
# TODO(BrianMichell): #614 Look into this further.
145+
# https://github.com/TGSAI/mdio-python/issues/584
145146
ds_to_write[header_key] = Variable(
146147
ds_to_write[header_key].dims,
147148
tmp_headers,
@@ -153,8 +154,9 @@ def trace_worker( # noqa: PLR0913
153154
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
154155
tmp_samples = np.full_like(data_variable, fill_value=fill_value)
155156
tmp_samples[not_null] = traces.sample
156-
# Create a new Variable object to avoid copying the temporary array
157+
157158
# TODO(BrianMichell): #614 Look into this further.
159+
# https://github.com/TGSAI/mdio-python/issues/584
158160
ds_to_write[data_variable_name] = Variable(
159161
ds_to_write[data_variable_name].dims,
160162
tmp_samples,
@@ -164,12 +166,13 @@ def trace_worker( # noqa: PLR0913
164166

165167
to_mdio(ds_to_write, output_path=output_path, region=region, mode="r+")
166168

169+
nonzero_samples = np.ma.masked_values(traces.sample, 0, copy=False)
167170
histogram = CenteredBinHistogram(bin_centers=[], counts=[])
168171
return SummaryStatistics(
169-
count=traces.sample.size,
170-
min=traces.sample.min(),
171-
max=traces.sample.max(),
172-
sum=traces.sample.sum(),
173-
sum_squares=(traces.sample**2).sum(),
172+
count=nonzero_samples.count(),
173+
min=nonzero_samples.min(),
174+
max=nonzero_samples.max(),
175+
sum=nonzero_samples.sum(dtype="float64"),
176+
sum_squares=(np.ma.power(nonzero_samples, 2).sum(dtype="float64")),
174177
histogram=histogram,
175178
)

tests/integration/test_segy_roundtrip_teapot.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,18 @@ def test_variable_metadata(self, zarr_tmp: Path) -> None:
181181
"""Metadata reading tests."""
182182
ds = open_mdio(zarr_tmp)
183183
expected_attrs = {
184-
"count": 97354860,
185-
"sum": -8594.551666259766,
186-
"sumSquares": 40571291.6875,
184+
"count": 46854270,
185+
"sum": -8594.551589292674,
186+
"sumSquares": 40571285.42351971,
187187
"min": -8.375323295593262,
188188
"max": 7.723702430725098,
189189
"histogram": {"counts": [], "binCenters": []},
190190
}
191-
actual_attrs_json = json.loads(ds["amplitude"].attrs["statsV1"])
192-
assert actual_attrs_json == expected_attrs
191+
actual_attrs = json.loads(ds["amplitude"].attrs["statsV1"])
192+
assert actual_attrs.keys() == expected_attrs.keys()
193+
actual_attrs.pop("histogram")
194+
expected_attrs.pop("histogram")
195+
np.testing.assert_allclose(list(actual_attrs.values()), list(expected_attrs.values()))
193196

194197
def test_grid(self, zarr_tmp: Path, teapot_segy_spec: SegySpec) -> None:
195198
"""Test validating MDIO variables."""
@@ -237,23 +240,22 @@ def test_inline_reads(self, zarr_tmp: Path) -> None:
237240
"""Read and compare every 75 inlines' mean and std. dev."""
238241
ds = open_mdio(zarr_tmp)
239242
inlines = ds["amplitude"][::75, :, :]
240-
mean, std = inlines.mean(), inlines.std()
241-
npt.assert_allclose([mean, std], [1.0555277e-04, 6.0027051e-01])
243+
mean, std = inlines.mean(dtype="float64"), inlines.std(dtype="float64")
244+
npt.assert_allclose([mean, std], [0.00010555267, 0.60027058412]) # 11 precision
242245

243246
def test_crossline_reads(self, zarr_tmp: Path) -> None:
244247
"""Read and compare every 75 crosslines' mean and std. dev."""
245248
ds = open_mdio(zarr_tmp)
246249
xlines = ds["amplitude"][:, ::75, :]
247-
mean, std = xlines.mean(), xlines.std()
248-
249-
npt.assert_allclose([mean, std], [-5.0329847e-05, 5.9406823e-01])
250+
mean, std = xlines.mean(dtype="float64"), xlines.std(dtype="float64")
251+
npt.assert_allclose([mean, std], [-5.03298501828e-05, 0.59406807762]) # 11 precision
250252

251253
def test_zslice_reads(self, zarr_tmp: Path) -> None:
252254
"""Read and compare every 225 z-slices' mean and std. dev."""
253255
ds = open_mdio(zarr_tmp)
254256
slices = ds["amplitude"][:, :, ::225]
255-
mean, std = slices.mean(), slices.std()
256-
npt.assert_allclose([mean, std], [0.005236923, 0.61279935])
257+
mean, std = slices.mean(dtype="float64"), slices.std(dtype="float64")
258+
npt.assert_allclose([mean, std], [0.00523692339, 0.61279943571]) # 11 precision
257259

258260
@pytest.mark.dependency("test_3d_import")
259261
def test_3d_export(

0 commit comments

Comments
 (0)