|
10 | 10 | import numpy as np |
11 | 11 | from segy import SegyFile |
12 | 12 |
|
| 13 | +from mdio.schemas import ScalarType |
| 14 | + |
13 | 15 | if TYPE_CHECKING: |
14 | 16 | from segy.arrays import HeaderArray |
15 | 17 | from segy.config import SegySettings |
|
19 | 21 |
|
20 | 22 | from mdio.core.storage_location import StorageLocation |
21 | 23 |
|
22 | | -from xarray import Variable |
23 | 24 |
|
24 | 25 | from mdio.constants import UINT32_MAX |
25 | 26 | from mdio.schemas.v1.dataset_serializer import _get_fill_value |
@@ -112,62 +113,37 @@ def trace_worker( # noqa: PLR0913 |
112 | 113 | live_trace_indexes = grid_map[not_null].tolist() |
113 | 114 | traces = segy_file.trace[live_trace_indexes] |
114 | 115 |
|
115 | | - hdr_key = "headers" |
| 116 | + header_key = "headers" |
116 | 117 |
|
117 | 118 | # Get subset of the dataset that has not yet been saved |
118 | 119 | # The headers might not be present in the dataset |
119 | | - # TODO(Dmitriy Repin): Check, should we overwrite the 'dataset' instead to save the memory |
120 | | - # https://github.com/TGSAI/mdio-python/issues/584 |
121 | 120 | worker_variables = [data_variable_name] |
122 | | - if hdr_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations |
123 | | - worker_variables.append(hdr_key) |
| 121 | + if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations |
| 122 | + worker_variables.append(header_key) |
124 | 123 |
|
125 | 124 | ds_to_write = dataset[worker_variables] |
126 | 125 |
|
127 | | - if hdr_key in worker_variables: |
| 126 | + if header_key in worker_variables: |
128 | 127 | # Create temporary array for headers with the correct shape |
129 | 128 | # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003 |
130 | | - tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write[hdr_key].dtype) |
| 129 | + tmp_headers = np.zeros_like(dataset[header_key]) |
131 | 130 | tmp_headers[not_null] = traces.header |
132 | | - # Create a new Variable object to avoid copying the temporary array |
133 | | - ds_to_write[hdr_key] = Variable( |
134 | | - ds_to_write[hdr_key].dims, |
135 | | - tmp_headers, |
136 | | - attrs=ds_to_write[hdr_key].attrs, |
137 | | - encoding=ds_to_write[hdr_key].encoding, # Not strictly necessary, but safer than not doing it. |
138 | | - ) |
139 | | - |
140 | | - # Get the sample dimension size from the data variable itself |
141 | | - sample_dim_size = ds_to_write[data_variable_name].shape[-1] |
142 | | - tmp_samples = np.full( |
143 | | - not_null.shape + (sample_dim_size,), |
144 | | - dtype=ds_to_write[data_variable_name].dtype, |
145 | | - fill_value=_get_fill_value(ds_to_write[data_variable_name].dtype), |
146 | | - ) |
| 131 | + ds_to_write[header_key][:] = tmp_headers |
147 | 132 |
|
148 | | - # Assign trace samples to the correct positions |
149 | | - # We need to handle the fact that traces.sample is (num_traces, num_samples) |
150 | | - # and we want to put it into positions where not_null is True |
| 133 | + data_variable = ds_to_write[data_variable_name] |
| 134 | + fill_value = _get_fill_value(ScalarType(data_variable.dtype.name)) |
| 135 | + tmp_samples = np.full_like(data_variable, fill_value=fill_value) |
151 | 136 | tmp_samples[not_null] = traces.sample |
152 | | - # Create a new Variable object to avoid copying the temporary array |
153 | | - ds_to_write[data_variable_name] = Variable( |
154 | | - ds_to_write[data_variable_name].dims, |
155 | | - tmp_samples, |
156 | | - attrs=ds_to_write[data_variable_name].attrs, |
157 | | - encoding=ds_to_write[data_variable_name].encoding, # Not strictly necessary, but safer than not doing it. |
158 | | - ) |
| 137 | + ds_to_write[data_variable_name][:] = tmp_samples |
159 | 138 |
|
160 | | - out_path = output_location.uri |
161 | | - ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2) |
| 139 | + ds_to_write.to_zarr(output_location.uri, region=region, mode="r+", write_empty_chunks=False, zarr_format=2) |
162 | 140 |
|
163 | 141 | histogram = CenteredBinHistogram(bin_centers=[], counts=[]) |
164 | 142 | return SummaryStatistics( |
165 | 143 | count=traces.sample.size, |
166 | 144 | min=traces.sample.min(), |
167 | 145 | max=traces.sample.max(), |
168 | 146 | sum=traces.sample.sum(), |
169 | | - # TODO(Altay): Look at how to do the sum squares statistic correctly |
170 | | - # https://github.com/TGSAI/mdio-python/issues/581 |
171 | 147 | sum_squares=(traces.sample**2).sum(), |
172 | 148 | histogram=histogram, |
173 | 149 | ) |
0 commit comments