
Commit 3d69893

Begin work on duplicate index override

1 parent 1417d97 · commit 3d69893
File tree: 7 files changed, +167 −13 lines changed

7 files changed

+167
-13
lines changed

src/mdio/builder/dataset_builder.py

Lines changed: 94 additions & 0 deletions
@@ -101,6 +101,100 @@ def add_dimension(self, name: str, size: int) -> "MDIODatasetBuilder":
         self._state = _BuilderState.HAS_DIMENSIONS
         return self
 
+    def push_dimension(self, dimension: NamedDimension, position: int, new_dim_chunk_size: int = 1) -> "MDIODatasetBuilder":
+        """Pushes a dimension to all Coordinates and Variables.
+
+        The position argument is the domain index of the dimension to push.
+        If the position falls within a Variable's domain, the dimension is inserted at that
+        position and all remaining dimensions are shifted to the right.
+
+        Args:
+            dimension: The dimension to push
+            position: The position to push the dimension to
+            new_dim_chunk_size: The chunk size for only the new dimension
+
+        Returns:
+            self: Returns self for method chaining
+        """
+        if position < 0:
+            msg = "Support for negative positions is not implemented yet!"
+            raise ValueError(msg)
+        if position > len(self._dimensions):
+            msg = "Position is greater than the number of dimensions"
+            raise ValueError(msg)
+        if new_dim_chunk_size <= 0:
+            # TODO(BrianMichell): Do we actually need to check this, or does Pydantic handle it when we call?
+            msg = "New dimension chunk size must be greater than 0"
+            raise ValueError(msg)
+
+        # print("########################### STATE BEFORE INSERTING DIMENSION ###########################")
+        # for d in self._dimensions:
+        #     print(d.model_dump_json())
+        # for c in self._coordinates:
+        #     print(c.model_dump_json())
+        # for v in self._variables:
+        #     print(v.model_dump_json())
+        # print("#########################################################################################")
+
+        # In-place insertion of the dimension into the existing list of dimensions
+        self._dimensions.insert(position, dimension)
+
+        # Earlier draft of the propagation helper, kept for reference:
+        # def propogate_dimension(variable: Variable, position: int, new_dim_chunk_size: int) -> Variable:
+        #     """Propagates the dimension to the variable or coordinate."""
+        #     if len(variable.dimensions) <= position:
+        #         # Don't do anything if the new dimension is not within the Variable's domain
+        #         return variable
+        #     if variable.name == "trace_mask":
+        #         # Special case for trace_mask. Don't do anything.
+        #         return variable
+        #     # new_dimensions = variable.dimensions[:position] + (dimension,) + variable.dimensions[position:]
+        #     # new_chunk_sizes = variable.chunk_sizes[:position] + (new_dim_chunk_size,) + variable.chunk_sizes[position:]
+        #     new_dimensions = variable.dimensions[:position] + [dimension] + variable.dimensions[position:]
+        #     new_chunk_sizes = variable.chunk_sizes[:position] + [new_dim_chunk_size] + variable.chunk_sizes[position:]
+        #     return variable.model_copy(update={"dimensions": new_dimensions, "chunk_sizes": new_chunk_sizes})
+        def propogate_dimension(variable: Variable, position: int, new_dim_chunk_size: int) -> Variable:
+            """Propagates the dimension to the variable or coordinate."""
+            from mdio.builder.schemas.chunk_grid import RegularChunkGrid, RegularChunkShape
+
+            if len(variable.dimensions) <= position:
+                # Don't do anything if the new dimension is not within the Variable's domain
+                return variable
+            if variable.name == "trace_mask":
+                # Special case for trace_mask. Don't do anything.
+                return variable
+
+            # new_dimensions = variable.dimensions[:position] + (dimension,) + variable.dimensions[position:]
+            # new_chunk_sizes = variable.chunk_sizes[:position] + (new_dim_chunk_size,) + variable.chunk_sizes[position:]
+            new_dimensions = variable.dimensions[:position] + [dimension] + variable.dimensions[position:]
+
+            # Get the current chunk shape from metadata
+            current_chunk_shape = (1,) * len(variable.dimensions)  # Default fallback
+            if variable.metadata is not None and variable.metadata.chunk_grid is not None:
+                current_chunk_shape = variable.metadata.chunk_grid.configuration.chunk_shape
+
+            # Insert the new chunk size at the correct position
+            new_chunk_shape = current_chunk_shape[:position] + (new_dim_chunk_size,) + current_chunk_shape[position:]
+
+            # Create the new chunk grid configuration
+            new_chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=new_chunk_shape))
+
+            # Update the metadata with the new chunk grid
+            new_metadata = variable.metadata.model_copy() if variable.metadata else VariableMetadata()
+            new_metadata.chunk_grid = new_chunk_grid
+            ret = variable.model_copy(update={"dimensions": new_dimensions, "metadata": new_metadata})
+            return ret
+
+        # Dimension and coordinate variables keep their original domain; only data variables are updated.
+        to_ignore = []
+        for v in self._dimensions:
+            to_ignore.append(v.name)
+        for c in self._coordinates:
+            to_ignore.append(c.name)
+
+        for i in range(len(self._variables)):
+            var = self._variables[i]
+            if var.name in to_ignore:
+                continue
+            self._variables[i] = propogate_dimension(var, position, new_dim_chunk_size)
+
+        return self
+
     def add_coordinate(  # noqa: PLR0913
         self,
         name: str,
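For illustration, a minimal usage sketch of the new push_dimension hook (not part of the commit). It assumes a builder whose dimensions are ("shot_point", "cable", "channel", "time") and whose data variables have already been added; the dimension name, size, and position below are hypothetical.

    from mdio.builder.schemas.dimension import NamedDimension

    # Insert a hypothetical "trace" axis at domain index 3, i.e. just before the
    # vertical "time" axis. Every data variable whose domain reaches position 3
    # gains the new axis, and its chunk grid gets chunk size 1 for that axis.
    builder.push_dimension(
        NamedDimension(name="trace", size=120),
        position=3,
        new_dim_chunk_size=1,
    )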

src/mdio/builder/templates/abstract_dataset_template.py

Lines changed: 10 additions & 0 deletions
@@ -4,6 +4,7 @@
 from abc import ABC
 from abc import abstractmethod
 from typing import Any
+from typing import Callable
 
 from mdio.builder.dataset_builder import MDIODatasetBuilder
 from mdio.builder.schemas import compressors
@@ -40,6 +41,11 @@ def __init__(self, data_domain: SeismicDataDomain) -> None:
         self._builder: MDIODatasetBuilder | None = None
         self._dim_sizes = ()
         self._horizontal_coord_unit = None
+        self._queued_transforms = []
+
+    def _queue_transform(self, transform: Callable) -> None:
+        """Queue a transform to be applied to the dataset once it has been built."""
+        self._queued_transforms.append(transform)
 
     def build_dataset(
         self,
@@ -71,6 +77,10 @@ def build_dataset(
             self._add_trace_mask()
         if header_dtype:
            self._add_trace_headers(header_dtype)
+
+        print(f"Number of queued transforms: {len(self._queued_transforms)}")
+        for transform in self._queued_transforms:
+            transform(self._builder)
         return self._builder.build()
 
     @property
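A short sketch of how the queued-transform hook is meant to be used from a template (hypothetical template instance and dimension details; the only API assumed is what the diff adds, namely _queue_transform and push_dimension).

    from mdio.builder.schemas.dimension import NamedDimension

    # Queue a deferred edit against the not-yet-existing builder; build_dataset()
    # calls each queued transform with self._builder right before _builder.build().
    template._queue_transform(
        lambda builder: builder.push_dimension(
            NamedDimension(name="trace", size=64), position=2, new_dim_chunk_size=1
        )
    )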

src/mdio/converters/segy.py

Lines changed: 16 additions & 7 deletions
@@ -134,12 +134,15 @@ def _scan_for_headers(
         grid_overrides=grid_overrides,
     )
     if full_chunk_size != chunk_size:
+        pass
         # TODO(Dmitriy): implement grid overrides
         # https://github.com/TGSAI/mdio-python/issues/585
         # The returned 'chunksize' is used only for grid_overrides. We will need to use it when full
         # support for grid overrides is implemented
-        err = "Support for changing full_chunk_size in grid overrides is not yet implemented"
-        raise NotImplementedError(err)
+        # err = "Support for changing full_chunk_size in grid overrides is not yet implemented"
+        # raise NotImplementedError(err)
+
+    full_chunk_size = template.full_chunk_size
     return segy_dimensions, segy_headers
@@ -157,6 +160,7 @@ def _build_and_check_grid(segy_dimensions: list[Dimension], segy_file: SegyFile,
     Raises:
         GridTraceCountError: If number of traces in SEG-Y file does not match the parsed grid
     """
+    # print(segy_dimensions)
     grid = Grid(dims=segy_dimensions)
     grid_density_qc(grid, segy_file.num_traces)
     grid.build_map(segy_headers)
@@ -370,20 +374,25 @@ def segy_to_mdio(  # noqa PLR0913
 
     grid = _build_and_check_grid(segy_dimensions, segy_file, segy_headers)
 
-    _, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
-    header_dtype = to_structured_type(segy_spec.trace.header.dtype)
-    horizontal_unit = _get_horizontal_coordinate_unit(segy_dimensions)
     mdio_ds: Dataset = mdio_template.build_dataset(
         name=mdio_template.name,
         sizes=grid.shape,
-        horizontal_coord_unit=horizontal_unit,
-        header_dtype=header_dtype,
+        horizontal_coord_unit=_get_horizontal_coordinate_unit(segy_dimensions),
+        header_dtype=to_structured_type(segy_spec.trace.header.dtype),
     )
 
+    # print(mdio_ds.model_dump_json())
+    for v in mdio_ds.variables:
+        print(f"Attempting to dump variable {v.name}... ", end="")
+        tmp = v.model_dump_json()
+        print("Good!")
+        # print(v.model_dump_json())
+
     _add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides)
 
     xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)
 
+    _, non_dim_coords = _get_coordinates(grid, segy_headers, mdio_template)
     xr_dataset, drop_vars_delayed = _populate_coordinates(
         dataset=xr_dataset,
         grid=grid,
src/mdio/segy/geometry.py

Lines changed: 6 additions & 0 deletions
@@ -601,6 +601,12 @@ def run(
         chunksize: Sequence[int] | None = None,
     ) -> tuple[HeaderArray, tuple[str], tuple[int]]:
         """Run grid overrides and return result."""
+
+        # print("="*100)
+        # print(index_headers.to_dict().keys())
+        # print(index_headers)
+        # print("="*100)
+
         for override in grid_overrides:
             if override in self.parameters:
                 continue

src/mdio/segy/parsers.py

Lines changed: 2 additions & 2 deletions
@@ -14,9 +14,9 @@
 
 from mdio.segy._workers import header_scan_worker
 
+from segy.arrays import HeaderArray
 if TYPE_CHECKING:
     from segy import SegyFile
-    from segy.arrays import HeaderArray
 
 default_cpus = cpu_count(logical=True)
 
@@ -73,4 +73,4 @@ def parse_headers(
     headers: list[HeaderArray] = list(lazy_work)
 
     # Merge blocks before return
-    return np.concatenate(headers)
+    return HeaderArray(np.concatenate(headers))
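A small sketch of the merge step above, assuming (as the commit itself does) that segy.arrays.HeaderArray can wrap an existing structured ndarray. Re-wrapping the concatenated blocks keeps the annotated HeaderArray return type regardless of what array type np.concatenate hands back; the dtype and block sizes below are hypothetical.

    import numpy as np
    from segy.arrays import HeaderArray  # runtime import, as moved out of TYPE_CHECKING above

    # Two hypothetical per-block header chunks with the same structured dtype.
    dtype = np.dtype([("shot_point", "i4"), ("cable", "i2")])
    blocks = [HeaderArray(np.zeros(2, dtype=dtype)), HeaderArray(np.zeros(3, dtype=dtype))]

    # Merge the blocks, then wrap explicitly so the result is a HeaderArray.
    merged = HeaderArray(np.concatenate(blocks))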

src/mdio/segy/utilities.py

Lines changed: 34 additions & 1 deletion
@@ -13,6 +13,7 @@
 from mdio.core import Dimension
 from mdio.segy.geometry import GridOverrider
 from mdio.segy.parsers import parse_headers
+from mdio.builder.schemas.dimension import NamedDimension
 
 if TYPE_CHECKING:
     from numpy.typing import DTypeLike
@@ -57,16 +58,48 @@ def get_grid_plan(  # noqa: C901
     horizontal_dimensions = template.dimension_names[:-1]
     horizontal_coordinates = horizontal_dimensions + template.coordinate_names
     headers_subset = parse_headers(segy_file=segy_file, subset=horizontal_coordinates)
+    from segy.arrays import HeaderArray
+
+    # reduced_headers_subset = HeaderArray(h for h in headers_subset.to_dict().keys() if h in horizontal_dimensions)
+    # Get field names to keep
+    fields_to_keep = [h for h in headers_subset.dtype.names if h in horizontal_dimensions]
+
+    # Create filtered copy using numpy's field selection
+    reduced_headers_subset = HeaderArray(headers_subset[fields_to_keep])
+
+    # print("="*100)
+    # print(headers_subset.to_dict().keys())
+    # print(headers_subset)
+    # print("="*100)
 
     # Handle grid overrides.
     override_handler = GridOverrider()
     headers_subset, horizontal_coordinates, chunksize = override_handler.run(
-        headers_subset,
+        # headers_subset,
+        reduced_headers_subset,
         horizontal_coordinates,
+        # horizontal_dimensions,
         chunksize=chunksize,
         grid_overrides=grid_overrides,
     )
 
+    # print("="*100)
+    # print(headers_subset.to_dict().keys())
+    # print(headers_subset)
+    # print("="*100)
+
+    if grid_overrides.get("HasDuplicates", False):
+        # print(f"Size of header subset: {headers_subset['trace'].size}")
+        # print("="*100)
+        # print(headers_subset)
+        # print("="*100)
+        pos = len(template.dimension_names) - 1  # TODO: Implement the negative position case...
+        # template._queue_transform(lambda builder: builder.push_dimension(Dimension(coords=np.arange(headers_subset["trace"].size), name="trace"), position=pos, new_dim_chunk_size=1))
+        # template._queue_transform(lambda builder: builder.push_dimension(Dimension(coords=headers_subset["trace"].size, name="trace"), position=pos, new_dim_chunk_size=1))
+        template._queue_transform(lambda builder: builder.push_dimension(NamedDimension(name="trace", size=headers_subset["trace"].size), position=pos, new_dim_chunk_size=1))
+        # horizontal_dimensions.append("trace")
+        horizontal_dimensions = (*horizontal_dimensions, "trace")
+
     dimensions = []
     for dim_name in horizontal_dimensions:
         dim_unique = np.unique(headers_subset[dim_name])
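For reference, a standalone sketch of the numpy field selection used to build reduced_headers_subset above (the field names here are hypothetical): indexing a structured array with a list of field names yields an array restricted to those fields, which is what gets handed to the grid-override handler in place of the full header subset.

    import numpy as np

    # Structured array standing in for the parsed SEG-Y headers.
    headers = np.zeros(4, dtype=[("shot_point", "i4"), ("cable", "i2"), ("gun", "i2")])
    horizontal_dimensions = ("shot_point", "cable")

    # Multi-field indexing with a list of names keeps only the dimension fields.
    fields_to_keep = [name for name in headers.dtype.names if name in horizontal_dimensions]
    reduced = headers[fields_to_keep]
    print(reduced.dtype.names)  # ('shot_point', 'cable')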

tests/integration/test_import_streamer_grid_overrides.py

Lines changed: 5 additions & 3 deletions
@@ -30,8 +30,9 @@
 
 # TODO(Altay): Finish implementing these grid overrides.
 # https://github.com/TGSAI/mdio-python/issues/612
-@pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.")
-@pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}])
+# @pytest.mark.skip(reason="NonBinned and HasDuplicates haven't been properly implemented yet.")
+# @pytest.mark.parametrize("grid_override", [{"NonBinned": True}, {"HasDuplicates": True}])
+@pytest.mark.parametrize("grid_override", [{"HasDuplicates": True}])
 @pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.C])
 class TestImport4DNonReg:
     """Test for 4D segy import with grid overrides."""
@@ -51,7 +52,8 @@ def test_import_4d_segy(  # noqa: PLR0913
         segy_spec=segy_spec,
         mdio_template=TemplateRegistry().get("PreStackShotGathers3DTime"),
         input_path=segy_path,
-        output_path=zarr_tmp,
+        # output_path=zarr_tmp,
+        output_path="test_has_duplicates.mdio",
         overwrite=True,
         grid_overrides=grid_override,
     )
