2424 from numpy .typing import NDArray
2525 from segy .arrays import HeaderArray
2626
27+ from mdio .builder .templates .base import AbstractDatasetTemplate
28+
2729
2830logger = logging .getLogger (__name__ )
2931
@@ -303,6 +305,7 @@ def transform(
303305 self ,
304306 index_headers : HeaderArray ,
305307 grid_overrides : dict [str , bool | int ],
308+ template : AbstractDatasetTemplate , # noqa: ARG002
306309 ) -> NDArray :
307310 """Perform the grid transform."""
308311
@@ -379,42 +382,38 @@ def transform(
379382 self ,
380383 index_headers : HeaderArray ,
381384 grid_overrides : dict [str , bool | int ],
385+ template : AbstractDatasetTemplate , # noqa: ARG002
382386 ) -> NDArray :
383387 """Perform the grid transform."""
384388 self .validate (index_headers , grid_overrides )
385389
386390 # Filter to only include dimension fields, not coordinate fields
387- # Coordinate fields typically have _x, _y suffixes or specific names like 'gun'
388- # We want to keep fields like shot_point, cable, channel but exclude source_coord_x, etc.
391+ # We want to keep fields like shot_point, cable, channel but exclude coordinate fields
392+ # Use the template's coordinate names to determine which fields are coordinates
393+ coordinate_fields = set (template .coordinate_names )
389394 dimension_fields = []
390- coordinate_field_patterns = ['_x' , '_y' , '_coord' , 'gun' , 'source' , 'group' ]
391-
395+
392396 for field_name in index_headers .dtype .names :
393397 # Skip if it's already trace
394- if field_name == ' trace' :
398+ if field_name == " trace" :
395399 continue
396- # Check if it looks like a coordinate field
397- is_coordinate = any (pattern in field_name for pattern in coordinate_field_patterns )
398- if not is_coordinate :
400+ # Check if this field is a coordinate field according to the template
401+ if field_name not in coordinate_fields :
399402 dimension_fields .append (field_name )
400-
403+
401404 # Extract only dimension fields for trace indexing
402- if dimension_fields :
403- dimension_headers = index_headers [dimension_fields ]
404- else :
405- # If no dimension fields, use all fields
406- dimension_headers = index_headers
407-
405+ dimension_headers = index_headers [dimension_fields ] if dimension_fields else index_headers
406+
408407 # Create trace indices based on dimension fields only
409408 dimension_headers_with_trace = analyze_non_indexed_headers (dimension_headers )
410-
409+
411410 # Add the trace field back to the full index_headers array
412- if dimension_headers_with_trace is not None and ' trace' in dimension_headers_with_trace .dtype .names :
411+ if dimension_headers_with_trace is not None and " trace" in dimension_headers_with_trace .dtype .names :
413412 # Extract just the trace values array (not the whole structured array)
414- trace_values = np .array (dimension_headers_with_trace [' trace' ])
413+ trace_values = np .array (dimension_headers_with_trace [" trace" ])
415414 # Append as a new field to the full headers
416- index_headers = rfn .append_fields (index_headers , ' trace' , trace_values , usemask = False )
417-
415+ index_headers = rfn .append_fields (index_headers , " trace" , trace_values , usemask = False )
416+
418417 return index_headers
419418
420419 def transform_index_names (self , index_names : Sequence [str ]) -> Sequence [str ]:
@@ -467,6 +466,7 @@ def transform(
467466 self ,
468467 index_headers : HeaderArray ,
469468 grid_overrides : dict [str , bool | int ],
469+ template : AbstractDatasetTemplate , # noqa: ARG002
470470 ) -> NDArray :
471471 """Perform the grid transform."""
472472 self .validate (index_headers , grid_overrides )
@@ -504,6 +504,7 @@ def transform(
504504 self ,
505505 index_headers : HeaderArray ,
506506 grid_overrides : dict [str , bool | int ],
507+ template : AbstractDatasetTemplate , # noqa: ARG002
507508 ) -> NDArray :
508509 """Perform the grid transform."""
509510 self .validate (index_headers , grid_overrides )
@@ -565,6 +566,7 @@ def run(
565566 index_names : Sequence [str ],
566567 grid_overrides : dict [str , bool ],
567568 chunksize : Sequence [int ] | None = None ,
569+ template : AbstractDatasetTemplate | None = None ,
568570 ) -> tuple [HeaderArray , tuple [str ], tuple [int ]]:
569571 """Run grid overrides and return result."""
570572 for override in grid_overrides :
@@ -575,7 +577,7 @@ def run(
575577 raise GridOverrideUnknownError (override )
576578
577579 function = self .commands [override ].transform
578- index_headers = function (index_headers , grid_overrides = grid_overrides )
580+ index_headers = function (index_headers , grid_overrides = grid_overrides , template = template )
579581
580582 function = self .commands [override ].transform_index_names
581583 index_names = function (index_names )
0 commit comments