Skip to content

Commit e4dd2f1

Browse files
BrianMichell authored and tasansal committed
Reimplement disaster recovery logic
1 parent 3c84eee commit e4dd2f1

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

src/mdio/builder/schemas/dtype.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class ScalarType(StrEnum):
3232
COMPLEX64 = "complex64"
3333
COMPLEX128 = "complex128"
3434
COMPLEX256 = "complex256"
35+
HEADERS_V3 = "r1920" # Raw number of BITS, must be a multiple of 8
3536

3637

3738
class StructuredField(CamelCaseStrictModel):

src/mdio/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,5 @@ class ZarrFormat(IntEnum):
6464
ScalarType.COMPLEX64: complex(np.nan, np.nan),
6565
ScalarType.COMPLEX128: complex(np.nan, np.nan),
6666
ScalarType.COMPLEX256: complex(np.nan, np.nan),
67+
ScalarType.HEADERS_V3: b"\x00" * 240,
6768
}

src/mdio/converters/segy.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
from mdio.converters.exceptions import GridTraceSparsityError
2323
from mdio.converters.type_converter import to_structured_type
2424
from mdio.core.grid import Grid
25+
from mdio.builder.schemas.chunk_grid import RegularChunkGrid
26+
from mdio.builder.schemas.chunk_grid import RegularChunkShape
27+
from mdio.builder.schemas.compressors import Blosc
28+
from mdio.builder.schemas.compressors import BloscCname
29+
from mdio.builder.schemas.dtype import ScalarType
2530
from mdio.segy import blocked_io
2631
from mdio.segy.utilities import get_grid_plan
2732

@@ -333,6 +338,58 @@ def _add_grid_override_to_metadata(dataset: Dataset, grid_overrides: dict[str, A
333338
dataset.metadata.attributes["gridOverrides"] = grid_overrides
334339

335340

341+
def _add_raw_headers_to_template(mdio_template: AbstractDatasetTemplate) -> AbstractDatasetTemplate:
    """Add a raw headers variable to an MDIO template by patching its ``_add_variables`` method.

    The template instance is mutated in place: its ``_add_variables`` attribute is replaced by a
    wrapper that first calls the original method and then registers one extra variable on the
    template's builder with the following characteristics:

    - Name: "raw_headers" (long name "Raw Headers")
    - Dimensions: all template dimensions except the last (vertical) one
    - Type: ScalarType.HEADERS_V3
    - No coordinates
    - Blosc compressor configured with the zstd codec
    - Chunk shape: the template's variable chunk shape without its last entry
    - No metadata other than the chunk grid

    Calling this function again on an already patched template is a no-op.

    Args:
        mdio_template: The MDIO template to mutate.

    Returns:
        The same template instance that was passed in.
    """
    # The marker attribute set at the end of this function makes repeated calls a no-op, so the
    # wrapper is never stacked and the variable is never registered twice.
    if hasattr(mdio_template, "_mdio_raw_headers_enhanced"):
        return mdio_template

    # Keep a reference to the current method so the wrapper can delegate to it.
    original_add_variables = mdio_template._add_variables

    def enhanced_add_variables() -> None:
        """Run the template's own variable setup, then append the raw headers variable."""
        # Function-scope import, kept local exactly as in the original implementation.
        from mdio.builder.schemas.v1.variable import VariableMetadata

        original_add_variables()

        # Template attributes are read when the wrapper runs, not when the patch is installed.
        chunk_shape = mdio_template._var_chunk_shape[:-1]
        chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=chunk_shape))

        mdio_template._builder.add_variable(
            name="raw_headers",
            long_name="Raw Headers",
            dimensions=mdio_template._dim_names[:-1],  # All dimensions except vertical
            data_type=ScalarType.HEADERS_V3,
            compressor=Blosc(cname=BloscCname.zstd),
            coordinates=None,
            metadata=VariableMetadata(chunk_grid=chunk_grid),
        )

    # Install the wrapper on this instance only, then mark the instance as patched.
    mdio_template._add_variables = enhanced_add_variables
    mdio_template._mdio_raw_headers_enhanced = True

    return mdio_template
391+
392+
336393
def segy_to_mdio( # noqa PLR0913
337394
segy_spec: SegySpec,
338395
mdio_template: AbstractDatasetTemplate,
@@ -372,6 +429,11 @@ def segy_to_mdio( # noqa PLR0913
372429

373430
_, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
374431
header_dtype = to_structured_type(segy_spec.trace.header.dtype)
432+
433+
if os.getenv("MDIO__DO_RAW_HEADERS") == "1":
434+
logger.warning("MDIO__DO_RAW_HEADERS is experimental and expected to change or be removed.")
435+
mdio_template = _add_raw_headers_to_template(mdio_template)
436+
375437
horizontal_unit = _get_horizontal_coordinate_unit(segy_dimensions)
376438
mdio_ds: Dataset = mdio_template.build_dataset(
377439
name=mdio_template.name,

src/mdio/converters/type_converter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def to_structured_type(data_type: np_dtype) -> StructuredType:
7878
def to_numpy_dtype(data_type: ScalarType | StructuredType) -> np_dtype:
7979
"""Get the numpy dtype for a variable."""
8080
if isinstance(data_type, ScalarType):
81+
if data_type == ScalarType.HEADERS_V3:
82+
return np_dtype("|V240")
8183
return np_dtype(data_type.value)
8284
if isinstance(data_type, StructuredType):
8385
return np_dtype([(f.name, f.format.value) for f in data_type.fields])

src/mdio/segy/_workers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,15 @@ def trace_worker( # noqa: PLR0913
124124
traces = segy_file.trace[live_trace_indexes]
125125

126126
header_key = "headers"
127+
raw_header_key = "raw_headers"
127128

128129
# Get subset of the dataset that has not yet been saved
129130
# The headers might not be present in the dataset
130131
worker_variables = [data_variable_name]
131132
if header_key in dataset.data_vars: # Keeping the `if` here to allow for more worker configurations
132133
worker_variables.append(header_key)
134+
if raw_header_key in dataset.data_vars:
135+
worker_variables.append(raw_header_key)
133136

134137
ds_to_write = dataset[worker_variables]
135138

@@ -148,7 +151,15 @@ def trace_worker( # noqa: PLR0913
148151
attrs=ds_to_write[header_key].attrs,
149152
encoding=ds_to_write[header_key].encoding, # Not strictly necessary, but safer than not doing it.
150153
)
151-
154+
if raw_header_key in worker_variables:
155+
tmp_raw_headers = np.zeros_like(dataset[raw_header_key])
156+
tmp_raw_headers[not_null] = traces.header.view("|V240") # TODO: Ensure this is using the RAW view and not an interpreted view.
157+
ds_to_write[raw_header_key] = Variable(
158+
ds_to_write[raw_header_key].dims,
159+
tmp_raw_headers,
160+
attrs=ds_to_write[raw_header_key].attrs,
161+
encoding=ds_to_write[raw_header_key].encoding, # Not strictly necessary, but safer than not doing it.
162+
)
152163
data_variable = ds_to_write[data_variable_name]
153164
fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
154165
tmp_samples = np.full_like(data_variable, fill_value=fill_value)

0 commit comments

Comments
 (0)