Skip to content

Commit 90fc97e

Browse files
committed
Clean up header value handler
1 parent 4542792 commit 90fc97e

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/mdio/segy/_workers.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -111,28 +111,29 @@ def trace_worker( # noqa: PLR0913
111111
live_trace_indexes = grid_map[not_null].tolist()
112112
traces = segy_file.trace[live_trace_indexes]
113113

114+
hdr_key = "headers"
115+
114116
# Get subset of the dataset that has not yet been saved
115117
# The headers might not be present in the dataset
116118
# TODO(Dmitriy Repin): Check whether we should overwrite 'dataset' instead, to save memory
117119
# https://github.com/TGSAI/mdio-python/issues/584
118-
if "headers" in dataset.data_vars:
119-
ds_to_write = dataset[[data_variable_name, "headers"]]
120-
ds_to_write = ds_to_write.reset_coords()
120+
worker_variables = [data_variable_name]
121+
if hdr_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
122+
worker_variables.append(hdr_key)
123+
124+
ds_to_write = dataset[worker_variables]
121125

126+
if hdr_key in worker_variables:
122127
# Create temporary array for headers with the correct shape
123-
tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write["headers"].dtype)
128+
tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write[hdr_key].dtype)
124129
tmp_headers[not_null] = traces.header
125130
# Create a new Variable object to avoid copying the temporary array
126-
ds_to_write["headers"] = Variable(
127-
ds_to_write["headers"].dims,
131+
ds_to_write[hdr_key] = Variable(
132+
ds_to_write[hdr_key].dims,
128133
tmp_headers,
129-
attrs=ds_to_write["headers"].attrs,
134+
attrs=ds_to_write[hdr_key].attrs,
130135
)
131136

132-
else:
133-
ds_to_write = dataset[[data_variable_name]]
134-
ds_to_write = ds_to_write.reset_coords()
135-
136137
# Get the sample dimension size from the data variable itself
137138
sample_dim_size = ds_to_write[data_variable_name].shape[-1]
138139
tmp_samples = np.zeros(not_null.shape + (sample_dim_size,), dtype=ds_to_write[data_variable_name].dtype)

0 commit comments

Comments
 (0)