Skip to content

Commit 8c94597

Browse files
committed
Refactor _workers.py to streamline variable handling, replace manual Variable creation with direct assignment, and resolve redundant imports.
1 parent 2ee2ab3 commit 8c94597

File tree

1 file changed

+13
-37
lines changed

1 file changed

+13
-37
lines changed

src/mdio/segy/_workers.py

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import numpy as np
1111
from segy import SegyFile
1212

13+
from mdio.schemas import ScalarType
14+
1315
if TYPE_CHECKING:
1416
from segy.arrays import HeaderArray
1517
from segy.config import SegySettings
@@ -19,7 +21,6 @@
1921

2022
from mdio.core.storage_location import StorageLocation
2123

22-
from xarray import Variable
2324

2425
from mdio.constants import UINT32_MAX
2526
from mdio.schemas.v1.dataset_serializer import _get_fill_value
@@ -112,62 +113,37 @@ def trace_worker( # noqa: PLR0913
112113
live_trace_indexes = grid_map[not_null].tolist()
113114
traces = segy_file.trace[live_trace_indexes]
114115

115-
hdr_key = "headers"
116+
header_key = "headers"
116117

117118
# Get subset of the dataset that has not yet been saved
118119
# The headers might not be present in the dataset
119-
# TODO(Dmitriy Repin): Check, should we overwrite the 'dataset' instead to save the memory
120-
# https://github.com/TGSAI/mdio-python/issues/584
121120
worker_variables = [data_variable_name]
122-
if hdr_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
123-
worker_variables.append(hdr_key)
121+
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
122+
worker_variables.append(header_key)
124123

125124
ds_to_write = dataset[worker_variables]
126125

127-
if hdr_key in worker_variables:
126+
if header_key in worker_variables:
128127
# Create temporary array for headers with the correct shape
129128
# TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
130-
tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write[hdr_key].dtype)
129+
tmp_headers = np.zeros_like(dataset[header_key])
131130
tmp_headers[not_null] = traces.header
132-
# Create a new Variable object to avoid copying the temporary array
133-
ds_to_write[hdr_key] = Variable(
134-
ds_to_write[hdr_key].dims,
135-
tmp_headers,
136-
attrs=ds_to_write[hdr_key].attrs,
137-
encoding=ds_to_write[hdr_key].encoding, # Not strictly necessary, but safer than not doing it.
138-
)
139-
140-
# Get the sample dimension size from the data variable itself
141-
sample_dim_size = ds_to_write[data_variable_name].shape[-1]
142-
tmp_samples = np.full(
143-
not_null.shape + (sample_dim_size,),
144-
dtype=ds_to_write[data_variable_name].dtype,
145-
fill_value=_get_fill_value(ds_to_write[data_variable_name].dtype),
146-
)
131+
ds_to_write[header_key][:] = tmp_headers
147132

148-
# Assign trace samples to the correct positions
149-
# We need to handle the fact that traces.sample is (num_traces, num_samples)
150-
# and we want to put it into positions where not_null is True
133+
data_variable = ds_to_write[data_variable_name]
134+
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
135+
tmp_samples = np.full_like(data_variable, fill_value=fill_value)
151136
tmp_samples[not_null] = traces.sample
152-
# Create a new Variable object to avoid copying the temporary array
153-
ds_to_write[data_variable_name] = Variable(
154-
ds_to_write[data_variable_name].dims,
155-
tmp_samples,
156-
attrs=ds_to_write[data_variable_name].attrs,
157-
encoding=ds_to_write[data_variable_name].encoding, # Not strictly necessary, but safer than not doing it.
158-
)
137+
ds_to_write[data_variable_name][:] = tmp_samples
159138

160-
out_path = output_location.uri
161-
ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
139+
ds_to_write.to_zarr(output_location.uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)
162140

163141
histogram = CenteredBinHistogram(bin_centers=[], counts=[])
164142
return SummaryStatistics(
165143
count=traces.sample.size,
166144
min=traces.sample.min(),
167145
max=traces.sample.max(),
168146
sum=traces.sample.sum(),
169-
# TODO(Altay): Look at how to do the sum squares statistic correctly
170-
# https://github.com/TGSAI/mdio-python/issues/581
171147
sum_squares=(traces.sample**2).sum(),
172148
histogram=histogram,
173149
)

0 commit comments

Comments
 (0)