Skip to content

Commit 99afc36

Browse files
BrianMichell authored and tasansal committed
Linting
1 parent 807106b commit 99afc36

File tree

4 files changed

94 additions and 96 deletions (+94 -96 lines changed)

4 files changed

94 additions and 96 deletions (+94 -96 lines changed)

src/mdio/segy/_disaster_recovery_wrapper.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44

55
from typing import TYPE_CHECKING
66

7+
from segy.schema import Endianness
8+
from segy.transforms import ByteSwapTransform
9+
from segy.transforms import IbmFloatTransform
10+
711
if TYPE_CHECKING:
8-
import numpy as np
9-
from segy.file import SegyFile
10-
from segy.transforms import Transform, ByteSwapTransform, IbmFloatTransform
1112
from numpy.typing import NDArray
13+
from segy import SegyFile
14+
from segy.transforms import Transform
15+
from segy.transforms import TransformPipeline
16+
1217

1318
def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
1419
"""Reverse a single transform operation."""
15-
from segy.schema import Endianness
16-
from segy.transforms import ByteSwapTransform
17-
from segy.transforms import IbmFloatTransform
18-
1920
if isinstance(transform, ByteSwapTransform):
2021
# Reverse the endianness conversion
2122
if endianness == Endianness.LITTLE:
@@ -24,20 +25,19 @@ def _reverse_single_transform(data: NDArray, transform: Transform, endianness: E
2425
reverse_transform = ByteSwapTransform(Endianness.BIG)
2526
return reverse_transform.apply(data)
2627

27-
elif isinstance(transform, IbmFloatTransform): # TODO: This seems incorrect...
28+
# TODO(BrianMichell): #0000 Do we actually need to worry about IBM/IEEE transforms here?
29+
if isinstance(transform, IbmFloatTransform):
2830
# Reverse IBM float conversion
2931
reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
3032
reverse_transform = IbmFloatTransform(reverse_direction, transform.keys)
3133
return reverse_transform.apply(data)
3234

33-
else:
34-
# For unknown transforms, return data unchanged
35-
return data
35+
# For unknown transforms, return data unchanged
36+
return data
37+
3638

3739
def get_header_raw_and_transformed(
38-
segy_file: SegyFile,
39-
indices: int | list[int] | NDArray | slice,
40-
do_reverse_transforms: bool = True
40+
segy_file: SegyFile, indices: int | list[int] | NDArray | slice, do_reverse_transforms: bool = True
4141
) -> tuple[NDArray | None, NDArray, NDArray]:
4242
"""Get both raw and transformed header data.
4343
@@ -54,15 +54,20 @@ def get_header_raw_and_transformed(
5454

5555
# Reverse transforms to get raw data
5656
if do_reverse_transforms:
57-
raw_headers = _reverse_transforms(transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness)
57+
raw_headers = _reverse_transforms(
58+
transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness
59+
)
5860
else:
5961
raw_headers = None
6062

6163
return raw_headers, transformed_headers, traces
6264

63-
def _reverse_transforms(transformed_data: NDArray, transform_pipeline, endianness: Endianness) -> NDArray:
65+
66+
def _reverse_transforms(
67+
transformed_data: NDArray, transform_pipeline: TransformPipeline, endianness: Endianness
68+
) -> NDArray:
6469
"""Reverse the transform pipeline to get raw data."""
65-
raw_data = transformed_data.copy() if hasattr(transformed_data, 'copy') else transformed_data
70+
raw_data = transformed_data.copy() if hasattr(transformed_data, "copy") else transformed_data
6671

6772
# Apply transforms in reverse order
6873
for transform in reversed(transform_pipeline.transforms):

src/mdio/segy/_workers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import numpy as np
1111
from segy import SegyFile
12-
from segy.indexing import merge_cat_file
1312

1413
from mdio.api.io import to_mdio
1514
from mdio.builder.schemas.dtype import ScalarType
@@ -82,6 +81,7 @@ def header_scan_worker(
8281

8382
return cast("HeaderArray", trace_header)
8483

84+
8585
def trace_worker( # noqa: PLR0913
8686
segy_kw: SegyFileArguments,
8787
output_path: UPath,
@@ -135,11 +135,12 @@ def trace_worker( # noqa: PLR0913
135135
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
136136
worker_variables.append(header_key)
137137
if raw_header_key in dataset.data_vars:
138-
139138
do_reverse_transforms = True
140139
worker_variables.append(raw_header_key)
141140

142-
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms)
141+
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
142+
segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms
143+
)
143144
ds_to_write = dataset[worker_variables]
144145

145146
if header_key in worker_variables:
@@ -168,7 +169,6 @@ def trace_worker( # noqa: PLR0913
168169
encoding=ds_to_write[raw_header_key].encoding, # Not strictly necessary, but safer than not doing it.
169170
)
170171

171-
172172
del raw_headers # Manage memory
173173
data_variable = ds_to_write[data_variable_name]
174174
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))

src/mdio/segy/blocked_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,4 @@ def to_segy(
280280

281281
non_consecutive_axes -= 1
282282

283-
return block_io_records
283+
return block_io_records

tests/unit/test_disaster_recovery_wrapper.py

Lines changed: 67 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
SAMPLES_PER_TRACE = 1501
2929

30+
3031
class TestDisasterRecoveryWrapper:
3132
"""Test cases for disaster recovery wrapper functionality."""
3233

@@ -51,17 +52,19 @@ def basic_segy_spec(self) -> SegySpec:
5152

5253
return spec.customize(trace_header_fields=header_fields)
5354

54-
@pytest.fixture(params=[
55-
{"endianness": Endianness.BIG, "data_format": 1, "name": "big_endian_ibm"},
56-
{"endianness": Endianness.BIG, "data_format": 5, "name": "big_endian_ieee"},
57-
{"endianness": Endianness.LITTLE, "data_format": 1, "name": "little_endian_ibm"},
58-
{"endianness": Endianness.LITTLE, "data_format": 5, "name": "little_endian_ieee"},
59-
])
60-
def segy_config(self, request) -> dict:
55+
@pytest.fixture(
56+
params=[
57+
{"endianness": Endianness.BIG, "data_format": 1, "name": "big_endian_ibm"},
58+
{"endianness": Endianness.BIG, "data_format": 5, "name": "big_endian_ieee"},
59+
{"endianness": Endianness.LITTLE, "data_format": 1, "name": "little_endian_ibm"},
60+
{"endianness": Endianness.LITTLE, "data_format": 5, "name": "little_endian_ieee"},
61+
]
62+
)
63+
def segy_config(self, request: pytest.FixtureRequest) -> dict:
6164
"""Parameterized fixture for different SEGY configurations."""
6265
return request.param
6366

64-
def create_test_segy_file(
67+
def create_test_segy_file( # noqa: PLR0913
6568
self,
6669
spec: SegySpec,
6770
num_traces: int,
@@ -119,7 +122,7 @@ def extract_header_bytes_from_file(
119122
self, segy_path: Path, trace_index: int, byte_start: int, byte_length: int
120123
) -> NDArray:
121124
"""Extract specific bytes from a trace header in the SEGY file."""
122-
with open(segy_path, "rb") as f:
125+
with segy_path.open("rb") as f:
123126
# Skip text header (3200 bytes) + binary header (400 bytes)
124127
header_offset = 3600
125128

@@ -164,18 +167,12 @@ def test_header_validation_configurations(
164167
for trace_idx in test_indices:
165168
# Get raw and transformed headers
166169
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
167-
segy_file=segy_file,
168-
indices=trace_idx,
169-
do_reverse_transforms=True
170+
segy_file=segy_file, indices=trace_idx, do_reverse_transforms=True
170171
)
171172

172173
# Extract bytes from disk for inline (bytes 189-192) and crossline (bytes 193-196)
173-
inline_bytes_disk = self.extract_header_bytes_from_file(
174-
segy_path, trace_idx, 189, 4
175-
)
176-
crossline_bytes_disk = self.extract_header_bytes_from_file(
177-
segy_path, trace_idx, 193, 4
178-
)
174+
inline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 189, 4)
175+
crossline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 193, 4)
179176

180177
# Convert raw headers to bytes for comparison
181178
if raw_headers is not None:
@@ -185,30 +182,30 @@ def test_header_validation_configurations(
185182
if raw_headers.ndim == 0:
186183
# Single trace case
187184
raw_data_bytes = raw_headers.tobytes()
188-
inline_offset = raw_headers.dtype.fields['inline'][1]
189-
crossline_offset = raw_headers.dtype.fields['crossline'][1]
190-
inline_size = raw_headers.dtype.fields['inline'][0].itemsize
191-
crossline_size = raw_headers.dtype.fields['crossline'][0].itemsize
192-
185+
inline_offset = raw_headers.dtype.fields["inline"][1]
186+
crossline_offset = raw_headers.dtype.fields["crossline"][1]
187+
inline_size = raw_headers.dtype.fields["inline"][0].itemsize
188+
crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize
189+
193190
raw_inline_bytes = np.frombuffer(
194-
raw_data_bytes[inline_offset:inline_offset+inline_size], dtype=np.uint8
191+
raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8
195192
)
196193
raw_crossline_bytes = np.frombuffer(
197-
raw_data_bytes[crossline_offset:crossline_offset+crossline_size], dtype=np.uint8
194+
raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8
198195
)
199196
else:
200197
# Multiple traces case - this test uses single trace index, so extract that trace
201198
raw_data_bytes = raw_headers[0:1].tobytes() # Extract first trace
202-
inline_offset = raw_headers.dtype.fields['inline'][1]
203-
crossline_offset = raw_headers.dtype.fields['crossline'][1]
204-
inline_size = raw_headers.dtype.fields['inline'][0].itemsize
205-
crossline_size = raw_headers.dtype.fields['crossline'][0].itemsize
206-
199+
inline_offset = raw_headers.dtype.fields["inline"][1]
200+
crossline_offset = raw_headers.dtype.fields["crossline"][1]
201+
inline_size = raw_headers.dtype.fields["inline"][0].itemsize
202+
crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize
203+
207204
raw_inline_bytes = np.frombuffer(
208-
raw_data_bytes[inline_offset:inline_offset+inline_size], dtype=np.uint8
205+
raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8
209206
)
210207
raw_crossline_bytes = np.frombuffer(
211-
raw_data_bytes[crossline_offset:crossline_offset+crossline_size], dtype=np.uint8
208+
raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8
212209
)
213210

214211
print(f"Transformed headers: {transformed_headers.tobytes()}")
@@ -217,10 +214,12 @@ def test_header_validation_configurations(
217214
print(f"Crossline bytes disk: {crossline_bytes_disk.tobytes()}")
218215

219216
# Compare bytes
220-
assert np.array_equal(raw_inline_bytes, inline_bytes_disk), \
217+
assert np.array_equal(raw_inline_bytes, inline_bytes_disk), (
221218
f"Inline bytes mismatch for trace {trace_idx} in {config_name}"
222-
assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), \
219+
)
220+
assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), (
223221
f"Crossline bytes mismatch for trace {trace_idx} in {config_name}"
222+
)
224223

225224
def test_header_validation_no_transforms(
226225
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict
@@ -252,7 +251,7 @@ def test_header_validation_no_transforms(
252251
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
253252
segy_file=segy_file,
254253
indices=slice(None), # All traces
255-
do_reverse_transforms=False
254+
do_reverse_transforms=False,
256255
)
257256

258257
# When transforms are disabled, raw_headers should be None
@@ -262,13 +261,8 @@ def test_header_validation_no_transforms(
262261
assert transformed_headers is not None
263262
assert transformed_headers.size == num_traces
264263

265-
def test_multiple_traces_validation(
266-
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict
267-
) -> None:
264+
def test_multiple_traces_validation(self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict) -> None:
268265
"""Test validation with multiple traces at once."""
269-
if True:
270-
import segy
271-
print(segy.__version__)
272266
config_name = segy_config["name"]
273267
endianness = segy_config["endianness"]
274268
data_format = segy_config["data_format"]
@@ -301,20 +295,16 @@ def test_multiple_traces_validation(
301295
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
302296
segy_file=segy_file,
303297
indices=slice(None), # All traces
304-
do_reverse_transforms=True
298+
do_reverse_transforms=True,
305299
)
306300

307301
first = True
308302

309303
# Validate each trace
310304
for trace_idx in range(num_traces):
311305
# Extract bytes from disk
312-
inline_bytes_disk = self.extract_header_bytes_from_file(
313-
segy_path, trace_idx, 189, 4
314-
)
315-
crossline_bytes_disk = self.extract_header_bytes_from_file(
316-
segy_path, trace_idx, 193, 4
317-
)
306+
inline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 189, 4)
307+
crossline_bytes_disk = self.extract_header_bytes_from_file(segy_path, trace_idx, 193, 4)
318308

319309
if first:
320310
print(raw_headers.dtype)
@@ -327,30 +317,30 @@ def test_multiple_traces_validation(
327317
if raw_headers.ndim == 0:
328318
# Single trace case
329319
raw_data_bytes = raw_headers.tobytes()
330-
inline_offset = raw_headers.dtype.fields['inline'][1]
331-
crossline_offset = raw_headers.dtype.fields['crossline'][1]
332-
inline_size = raw_headers.dtype.fields['inline'][0].itemsize
333-
crossline_size = raw_headers.dtype.fields['crossline'][0].itemsize
334-
320+
inline_offset = raw_headers.dtype.fields["inline"][1]
321+
crossline_offset = raw_headers.dtype.fields["crossline"][1]
322+
inline_size = raw_headers.dtype.fields["inline"][0].itemsize
323+
crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize
324+
335325
raw_inline_bytes = np.frombuffer(
336-
raw_data_bytes[inline_offset:inline_offset+inline_size], dtype=np.uint8
326+
raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8
337327
)
338328
raw_crossline_bytes = np.frombuffer(
339-
raw_data_bytes[crossline_offset:crossline_offset+crossline_size], dtype=np.uint8
329+
raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8
340330
)
341331
else:
342332
# Multiple traces case
343-
raw_data_bytes = raw_headers[trace_idx:trace_idx+1].tobytes()
344-
inline_offset = raw_headers.dtype.fields['inline'][1]
345-
crossline_offset = raw_headers.dtype.fields['crossline'][1]
346-
inline_size = raw_headers.dtype.fields['inline'][0].itemsize
347-
crossline_size = raw_headers.dtype.fields['crossline'][0].itemsize
348-
333+
raw_data_bytes = raw_headers[trace_idx : trace_idx + 1].tobytes()
334+
inline_offset = raw_headers.dtype.fields["inline"][1]
335+
crossline_offset = raw_headers.dtype.fields["crossline"][1]
336+
inline_size = raw_headers.dtype.fields["inline"][0].itemsize
337+
crossline_size = raw_headers.dtype.fields["crossline"][0].itemsize
338+
349339
raw_inline_bytes = np.frombuffer(
350-
raw_data_bytes[inline_offset:inline_offset+inline_size], dtype=np.uint8
340+
raw_data_bytes[inline_offset : inline_offset + inline_size], dtype=np.uint8
351341
)
352342
raw_crossline_bytes = np.frombuffer(
353-
raw_data_bytes[crossline_offset:crossline_offset+crossline_size], dtype=np.uint8
343+
raw_data_bytes[crossline_offset : crossline_offset + crossline_size], dtype=np.uint8
354344
)
355345

356346
print(f"Raw inline bytes: {raw_inline_bytes.tobytes()}")
@@ -359,18 +349,23 @@ def test_multiple_traces_validation(
359349
print(f"Crossline bytes disk: {crossline_bytes_disk.tobytes()}")
360350

361351
# Compare
362-
assert np.array_equal(raw_inline_bytes, inline_bytes_disk), \
352+
assert np.array_equal(raw_inline_bytes, inline_bytes_disk), (
363353
f"Inline bytes mismatch for trace {trace_idx} in {config_name}"
364-
assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), \
354+
)
355+
assert np.array_equal(raw_crossline_bytes, crossline_bytes_disk), (
365356
f"Crossline bytes mismatch for trace {trace_idx} in {config_name}"
357+
)
366358

367-
@pytest.mark.parametrize("trace_indices", [
368-
0, # Single trace
369-
[0, 2, 4], # Multiple specific traces
370-
slice(1, 4), # Range of traces
371-
])
359+
@pytest.mark.parametrize(
360+
"trace_indices",
361+
[
362+
0, # Single trace
363+
[0, 2, 4], # Multiple specific traces
364+
slice(1, 4), # Range of traces
365+
],
366+
)
372367
def test_different_index_types(
373-
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict, trace_indices
368+
self, temp_dir: Path, basic_segy_spec: SegySpec, segy_config: dict, trace_indices: int | list[int] | slice
374369
) -> None:
375370
"""Test with different types of trace indices."""
376371
config_name = segy_config["name"]
@@ -397,9 +392,7 @@ def test_different_index_types(
397392

398393
# Get headers with different index types
399394
raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
400-
segy_file=segy_file,
401-
indices=trace_indices,
402-
do_reverse_transforms=True
395+
segy_file=segy_file, indices=trace_indices, do_reverse_transforms=True
403396
)
404397

405398
# Basic validation that we got results

0 commit comments

Comments
 (0)