@@ -110,12 +110,21 @@ class GridOverrideCommand(ABC):
110110 def required_keys (self ) -> set :
111111 """Get the set of required keys for the grid override command."""
112112
113+ @property
114+ @abstractmethod
115+ def required_parameters (self ) -> set :
116+ """Get the set of required parameters for the grid override command."""
117+
113118 @abstractmethod
114- def validate (self , index_headers : npt .NDArray , ** kwargs ) -> None :
119+ def validate (
120+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
121+ ) -> None :
115122 """Validate if this transform should run on the type of data."""
116123
117124 @abstractmethod
118- def transform (self , index_headers : npt .NDArray , ** kwargs ) -> npt .NDArray :
125+ def transform (
126+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
127+ ) -> dict [str , npt .NDArray ]:
119128 """Perform the grid transform."""
120129
121130 @property
@@ -129,25 +138,42 @@ def check_required_keys(self, index_headers: npt.NDArray) -> None:
129138 if not self .required_keys .issubset (index_names ):
130139 raise GridOverrideKeysError (self .name , self .required_keys )
131140
141+ def check_required_params (self , grid_overrides : dict [str , str | int ]) -> None :
142+ """Check if all required keys are present in the index headers."""
143+ if self .required_parameters is None :
144+ return
145+
146+ passed_parameters = set (grid_overrides .keys ())
147+
148+ if not self .required_parameters .issubset (passed_parameters ):
149+ missing_params = self .required_parameters - passed_parameters
150+ raise GridOverrideMissingParameterError (self .name , missing_params )
151+
132152
133153class AutoChannelWrap (GridOverrideCommand ):
134154 """Automatically determine Streamer acquisition type."""
135155
136156 required_keys = {"shot" , "cable" , "channel" }
157+ required_parameters = None
137158
138- def validate (self , index_headers : npt .NDArray , ** kwargs ) -> None :
159+ def validate (
160+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
161+ ) -> None :
139162 """Validate if this transform should run on the type of data."""
140- self .check_required_keys (index_headers )
141-
142- if "ChannelWrap" in kwargs :
163+ if "ChannelWrap" in grid_overrides :
143164 raise GridOverrideIncompatibleError (self .name , "ChannelWrap" )
144165
145- if "CalculateCable" in kwargs :
166+ if "CalculateCable" in grid_overrides :
146167 raise GridOverrideIncompatibleError (self .name , "CalculateCable" )
147168
148- def transform (self , index_headers : npt .NDArray , ** kwargs ):
169+ self .check_required_keys (index_headers )
170+ self .check_required_params (grid_overrides )
171+
172+ def transform (
173+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
174+ ) -> dict [str , npt .NDArray ]:
149175 """Perform the grid transform."""
150- self .validate (index_headers , ** kwargs )
176+ self .validate (index_headers , grid_overrides )
151177
152178 result = analyze_streamer_headers (index_headers )
153179 unique_cables , cable_chan_min , cable_chan_max , geom_type = result
@@ -179,22 +205,25 @@ class ChannelWrap(GridOverrideCommand):
179205 """Wrap channels to start from one at cable boundaries."""
180206
181207 required_keys = {"shot" , "cable" , "channel" }
208+ required_parameters = {"ChannelsPerCable" }
182209
183- def validate (self , index_headers : npt .NDArray , ** kwargs ) -> None :
210+ def validate (
211+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
212+ ) -> None :
184213 """Validate if this transform should run on the type of data."""
185- self .check_required_keys (index_headers )
186-
187- if "ChannelsPerCable" not in kwargs :
188- raise GridOverrideMissingParameterError (self .name , "ChannelsPerCable" )
189-
190- if "AutoCableChannel" in kwargs :
214+ if "AutoChannelWrap" in grid_overrides :
191215 raise GridOverrideIncompatibleError (self .name , "AutoCableChannel" )
192216
193- def transform (self , index_headers : npt .NDArray , ** kwargs ) -> npt .NDArray :
217+ self .check_required_keys (index_headers )
218+ self .check_required_params (grid_overrides )
219+
220+ def transform (
221+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
222+ ) -> dict [str , npt .NDArray ]:
194223 """Perform the grid transform."""
195- self .validate (index_headers , ** kwargs )
224+ self .validate (index_headers , grid_overrides )
196225
197- channels_per_cable = kwargs ["ChannelsPerCable" ]
226+ channels_per_cable = grid_overrides ["ChannelsPerCable" ]
198227 index_headers ["channel" ] = (
199228 index_headers ["channel" ] - 1
200229 ) % channels_per_cable + 1
@@ -206,20 +235,25 @@ class CalculateCable(GridOverrideCommand):
206235 """Calculate cable numbers from unwrapped channels."""
207236
208237 required_keys = {"shot" , "cable" , "channel" }
238+ required_parameters = {"ChannelsPerCable" }
209239
210- def validate (self , index_headers : npt .NDArray , ** kwargs ) -> None :
240+ def validate (
241+ self , index_headers : npt .NDArray , grid_overrides : dict [str , bool | int ]
242+ ) -> None :
211243 """Validate if this transform should run on the type of data."""
212- self .check_required_keys (index_headers )
213-
214- if "ChannelsPerCable" not in kwargs :
215- raise GridOverrideMissingParameterError (self .name , "ChannelsPerCable" )
216-
217- if "AutoCableChannel" in kwargs :
244+ if "AutoChannelWrap" in grid_overrides :
218245 raise GridOverrideIncompatibleError (self .name , "AutoCableChannel" )
219246
220- def transform (self , index_headers , ** kwargs ):
247+ self .check_required_keys (index_headers )
248+ self .check_required_params (grid_overrides )
249+
250+ def transform (
251+ self , index_headers , grid_overrides : dict [str , bool | int ]
252+ ) -> dict [str , npt .NDArray ]:
221253 """Perform the grid transform."""
222- channels_per_cable = kwargs ["ChannelsPerCable" ]
254+ self .validate (index_headers , grid_overrides )
255+
256+ channels_per_cable = grid_overrides ["ChannelsPerCable" ]
223257 index_headers ["cable" ] = (
224258 index_headers ["channel" ] - 1
225259 ) // channels_per_cable + 1
@@ -237,24 +271,40 @@ class GridOverrider:
237271 """
238272
239273 def __init__ (self ):
240- """Define allowed overrides here."""
274+ """Define allowed overrides and parameters here."""
241275 self .commands = {
242276 "AutoChannelWrap" : AutoChannelWrap (),
243277 "CalculateCable" : CalculateCable (),
244278 "ChannelWrap" : ChannelWrap (),
245279 }
246280
281+ self .parameters = self .get_allowed_parameters ()
282+
283+ def get_allowed_parameters (self ) -> set :
284+ """Get list of allowed parameters from the allowed commands."""
285+ parameters = set ()
286+ for command in self .commands .values ():
287+ if command .required_parameters is None :
288+ continue
289+
290+ parameters .update (command .required_parameters )
291+
292+ return parameters
293+
247294 def run (
248295 self ,
249296 index_headers : npt .NDArray ,
250297 grid_overrides : dict [str , bool ],
251298 ) -> npt .NDArray :
252299 """Run grid overrides and return result."""
253300 for override in grid_overrides :
254- if override in self .commands :
255- function = self . commands [ override ]. transform
256- index_headers = function ( index_headers , grid_overrides = grid_overrides )
257- else :
301+ if override in self .parameters :
302+ continue
303+
304+ if override not in self . commands :
258305 raise GridOverrideUnknownError (override )
259306
307+ function = self .commands [override ].transform
308+ index_headers = function (index_headers , grid_overrides = grid_overrides )
309+
260310 return index_headers
0 commit comments