Skip to content

Commit ee7e81d

Browse files
committed
Avoid unnecessary copies of data in memory
1 parent a19c6b9 commit ee7e81d

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

src/mdio/segy/_workers.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from mdio.constants import UINT32_MAX
2323
from mdio.schemas.v1.stats import CenteredBinHistogram
2424
from mdio.schemas.v1.stats import SummaryStatistics
25+
from xarray import Variable
2526

2627

2728
class SegyFileArguments(TypedDict):
@@ -117,40 +118,34 @@ def trace_worker( # noqa: PLR0913
117118
ds_to_write = dataset[[data_variable_name, "headers"]]
118119
ds_to_write = ds_to_write.reset_coords()
119120

120-
try:
121-
# Create temporary array for headers with the correct shape
122-
tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write["headers"].dtype)
123-
tmp_headers[not_null] = traces.header
124-
ds_to_write["headers"].data[:] = tmp_headers
125-
except Exception as e:
126-
print(f"Error writing headers: {e}")
127-
print(f"not_null.shape: {not_null.shape}")
128-
print(f"traces.header.shape: {traces.header.shape}")
129-
print(f"ds_to_write['headers'].data.shape: {ds_to_write['headers'].data.shape}")
130-
raise e
121+
# Create temporary array for headers with the correct shape
122+
tmp_headers = np.zeros(not_null.shape, dtype=ds_to_write["headers"].dtype)
123+
tmp_headers[not_null] = traces.header
124+
# Create a new Variable object to avoid copying the temporary array
125+
ds_to_write["headers"] = Variable(
126+
ds_to_write["headers"].dims,
127+
tmp_headers,
128+
attrs = ds_to_write["headers"].attrs,
129+
)
131130

132131
else:
133132
ds_to_write = dataset[[data_variable_name]]
134133
ds_to_write = ds_to_write.reset_coords()
135134

136-
try:
137-
# Get the sample dimension size from the data variable itself
138-
sample_dim_size = ds_to_write[data_variable_name].shape[-1]
139-
tmp_samples = np.zeros(not_null.shape + (sample_dim_size,), dtype=ds_to_write[data_variable_name].dtype)
135+
# Get the sample dimension size from the data variable itself
136+
sample_dim_size = ds_to_write[data_variable_name].shape[-1]
137+
tmp_samples = np.zeros(not_null.shape + (sample_dim_size,), dtype=ds_to_write[data_variable_name].dtype)
140138

141-
# Assign trace samples to the correct positions
142-
# We need to handle the fact that traces.sample is (num_traces, num_samples)
143-
# and we want to put it into positions where not_null is True
144-
tmp_samples[not_null] = traces.sample
145-
ds_to_write[data_variable_name].data[:] = tmp_samples
146-
except Exception as e:
147-
print(f"Error writing samples: {e}")
148-
print(f"not_null.shape: {not_null.shape}")
149-
print(f"traces.sample.shape: {traces.sample.shape}")
150-
print(f"ds_to_write[data_variable_name].data.shape: {ds_to_write[data_variable_name].data.shape}")
151-
print(f"not_null.sum(): {not_null.sum()}")
152-
print(f"len(traces.sample): {len(traces.sample)}")
153-
raise e
139+
# Assign trace samples to the correct positions
140+
# We need to handle the fact that traces.sample is (num_traces, num_samples)
141+
# and we want to put it into positions where not_null is True
142+
tmp_samples[not_null] = traces.sample
143+
# Create a new Variable object to avoid copying the temporary array
144+
ds_to_write[data_variable_name] = Variable(
145+
ds_to_write[data_variable_name].dims,
146+
tmp_samples,
147+
attrs = ds_to_write[data_variable_name].attrs,
148+
)
154149

155150
out_path = output_location.uri
156151
ds_to_write.to_zarr(out_path, region=region, mode="r+", write_empty_chunks=False, zarr_format=2)

0 commit comments

Comments
 (0)