| 
32 | 32 | from mdio.converters.exceptions import GridTraceSparsityError  | 
33 | 33 | from mdio.converters.type_converter import to_structured_type  | 
34 | 34 | from mdio.core.grid import Grid  | 
 | 35 | +from mdio.core.utils_write import MAX_COORDINATES_BYTES  | 
 | 36 | +from mdio.core.utils_write import MAX_SIZE_LIVE_MASK  | 
 | 37 | +from mdio.core.utils_write import get_constrained_chunksize  | 
35 | 38 | from mdio.segy import blocked_io  | 
36 | 39 | from mdio.segy.utilities import get_grid_plan  | 
37 | 40 | 
 
  | 
@@ -429,6 +432,36 @@ def enhanced_add_variables() -> None:  | 
429 | 432 |     return mdio_template  | 
430 | 433 | 
 
  | 
431 | 434 | 
 
  | 
 | 435 | +def _chunk_variable(ds: Dataset, variable_name: str) -> None:  | 
 | 436 | +    """Determines the chunking for a Varible in the Dataset."""  | 
 | 437 | +    idx = -1  | 
 | 438 | +    for i in range(len(ds.variables)):  | 
 | 439 | +        if ds.variables[i].name == variable_name:  | 
 | 440 | +            idx = i  | 
 | 441 | +            break  | 
 | 442 | + | 
 | 443 | +    def determine_target_size(var_type: str) -> int:  | 
 | 444 | +        """Determines the target size (in bytes) for a Variable based on its type."""  | 
 | 445 | +        if var_type == "bool":  | 
 | 446 | +            return MAX_SIZE_LIVE_MASK  | 
 | 447 | +        return MAX_COORDINATES_BYTES  | 
 | 448 | + | 
 | 449 | +    # Create the chunk grid metadata  | 
 | 450 | +    var_type = ds.variables[idx].data_type  | 
 | 451 | +    full_shape = tuple(dim.size for dim in ds.variables[idx].dimensions)  | 
 | 452 | +    target_size = determine_target_size(var_type)  | 
 | 453 | + | 
 | 454 | +    chunk_shape = get_constrained_chunksize(full_shape, var_type, target_size)  | 
 | 455 | +    chunks = RegularChunkGrid(configuration=RegularChunkShape(chunk_shape=chunk_shape))  | 
 | 456 | + | 
 | 457 | +    # Update the variable's metadata with the new chunk grid  | 
 | 458 | +    if ds.variables[idx].metadata is None:  | 
 | 459 | +        # ds.variables[idx].metadata = VariableMetadata(chunk_shape=chunks.chunk_shape)  | 
 | 460 | +        ds.variables[idx].metadata = VariableMetadata(chunk_grid=chunks)  | 
 | 461 | +    else:  | 
 | 462 | +        ds.variables[idx].metadata.chunk_grid = chunks  | 
 | 463 | + | 
 | 464 | + | 
432 | 465 | def segy_to_mdio(  # noqa PLR0913  | 
433 | 466 |     segy_spec: SegySpec,  | 
434 | 467 |     mdio_template: AbstractDatasetTemplate,  | 
@@ -487,6 +520,10 @@ def segy_to_mdio(  # noqa PLR0913  | 
487 | 520 | 
 
  | 
488 | 521 |     _add_grid_override_to_metadata(dataset=mdio_ds, grid_overrides=grid_overrides)  | 
489 | 522 | 
 
  | 
 | 523 | +    _chunk_variable(ds=mdio_ds, variable_name="trace_mask")  # trace_mask is a Variable and not a Coordinate  | 
 | 524 | +    for coord in mdio_template.coordinate_names:  | 
 | 525 | +        _chunk_variable(ds=mdio_ds, variable_name=coord)  | 
 | 526 | + | 
490 | 527 |     xr_dataset: xr_Dataset = to_xarray_dataset(mdio_ds=mdio_ds)  | 
491 | 528 | 
 
  | 
492 | 529 |     xr_dataset, drop_vars_delayed = _populate_coordinates(  | 
 | 
0 commit comments