|
32 | 32 | from mdio.converters.exceptions import GridTraceSparsityError |
33 | 33 | from mdio.converters.type_converter import to_structured_type |
34 | 34 | from mdio.core.grid import Grid |
| 35 | +from mdio.core.utils_write import MAX_COORDINATES_BYTES |
| 36 | +from mdio.core.utils_write import MAX_SIZE_LIVE_MASK |
| 37 | +from mdio.core.utils_write import get_constrained_chunksize |
35 | 38 | from mdio.segy import blocked_io |
36 | 39 | from mdio.segy.utilities import get_grid_plan |
37 | 40 |
|
@@ -429,6 +432,32 @@ def enhanced_add_variables() -> None: |
429 | 432 | return mdio_template |
430 | 433 |
|
431 | 434 |
|
| 435 | +def _chunk_variable(ds: Dataset, target_variable_name: str) -> None: |
| 436 | + """Determines and sets the chunking for a specific Variable in the Dataset.""" |
| 437 | + # Find variable index by name |
| 438 | + index = next((i for i, obj in enumerate(ds.variables) if obj.name == target_variable_name), None) |
| 439 | + |
| 440 | + def determine_target_size(var_type: str) -> int: |
| 441 | + """Determines the target size (in bytes) for a Variable based on its type.""" |
| 442 | + if var_type == "bool": |
| 443 | + return MAX_SIZE_LIVE_MASK |
| 444 | + return MAX_COORDINATES_BYTES |
| 445 | + |
| 446 | + # Create the chunk grid metadata |
| 447 | + var_type = ds.variables[index].data_type |
| 448 | + full_shape = tuple(dim.size for dim in ds.variables[index].dimensions) |
| 449 | + target_size = determine_target_size(var_type) |
| 450 | + |
| 451 | + chunk_shape = get_constrained_chunksize(full_shape, var_type, target_size) |
| 452 | + chunk_grid = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=chunk_shape)) |
| 453 | + |
| 454 | + # Create variable metadata if it doesn't exist |
| 455 | + if ds.variables[index].metadata is None: |
| 456 | + ds.variables[index].metadata = VariableMetadata() |
| 457 | + |
| 458 | + ds.variables[index].metadata.chunk_grid = chunk_grid |
| 459 | + |
| 460 | + |
432 | 461 | def segy_to_mdio( # noqa PLR0913 |
433 | 462 | segy_spec: SegySpec, |
434 | 463 | mdio_template: AbstractDatasetTemplate, |
@@ -487,6 +516,11 @@ def segy_to_mdio( # noqa PLR0913 |
487 | 516 |
|
488 | 517 | _add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides) |
489 | 518 |
|
| 519 | + # Dynamically chunk the variables based on their type |
| 520 | + _chunk_variable(ds=mdio_ds, target_variable_name="trace_mask") # trace_mask is a Variable and not a Coordinate |
| 521 | + for coord in mdio_template.coordinate_names: |
| 522 | + _chunk_variable(ds=mdio_ds, target_variable_name=coord) |
| 523 | + |
490 | 524 | xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds) |
491 | 525 |
|
492 | 526 | xr_dataset, drop_vars_delayed = _populate_coordinates( |
|
0 commit comments