3636 GeometryValidator ,
3737 Validator ,
3838)
39- from gpgi ._typing import FieldMap , FloatT , Name
39+ from gpgi ._typing import D1 , D , F , FieldMap , Name
4040
4141if sys .version_info >= (3 , 13 ):
4242 LockType = Lock
@@ -90,7 +90,7 @@ class DepositionMethod(Enum):
9090@final
9191class GridCoordinatesValidator :
9292 @classmethod
93- def collect_exceptions (cls , data : Grid [FloatT ]) -> list [Exception ]:
93+ def collect_exceptions (cls , data : Grid [D , F ]) -> list [Exception ]:
9494 return FieldMapsValidatorHelper .collect_exceptions (
9595 data .coordinates ,
9696 require_sorted = True ,
@@ -101,7 +101,7 @@ def collect_exceptions(cls, data: Grid[FloatT]) -> list[Exception]:
101101@final
102102class GridFieldsValidator :
103103 @classmethod
104- def collect_exceptions (cls , data : Grid [FloatT ]) -> list [Exception ]:
104+ def collect_exceptions (cls , data : Grid [D , F ]) -> list [Exception ]:
105105 return FieldMapsValidatorHelper .collect_exceptions (
106106 data .fields ,
107107 required_attrs = {
@@ -112,19 +112,19 @@ def collect_exceptions(cls, data: Grid[FloatT]) -> list[Exception]:
112112 )
113113
114114
115- def is_uniformly_spaced (arr : np .ndarray [tuple [int ], np .dtype [FloatT ]]) -> bool :
115+ def is_uniformly_spaced (arr : np .ndarray [tuple [int ], np .dtype [F ]]) -> bool :
116116 dbase = np .diff (arr )
117117 return bool (dbase .std () / dbase .max () / dbase .size < 5 * np .finfo (arr .dtype ).eps )
118118
119119
120120@final
121- class Grid (Generic [FloatT ]):
121+ class Grid (Generic [D , F ]):
122122 def __init__ (
123123 self ,
124124 * ,
125125 geometry : Geometry ,
126- cell_edges : FieldMap [FloatT ],
127- fields : FieldMap [FloatT ] | None = None ,
126+ cell_edges : FieldMap [D1 , F ],
127+ fields : FieldMap [D , F ] | None = None ,
128128 ) -> None :
129129 r"""
130130 Define a Grid from cell left-edges and data fields.
@@ -140,25 +140,25 @@ def __init__(
140140 fields (keyword-only, optional): gpgi.typing.FieldMap
141141 """
142142 self .geometry = geometry
143- self .coordinates : FieldMap [FloatT ] = cell_edges
143+ self .coordinates : FieldMap [D1 , F ] = cell_edges
144144
145145 if fields is None :
146146 fields = {}
147- self .fields : FieldMap [FloatT ] = fields
147+ self .fields : FieldMap [D , F ] = fields
148148
149149 self .axes = tuple (self .coordinates .keys ())
150150 self ._validate ()
151- self .dtype : np .dtype [FloatT ] = self .coordinates [self .axes [0 ]].dtype
151+ self .dtype : np .dtype [F ] = self .coordinates [self .axes [0 ]].dtype
152152
153- self ._dx : NDArray [FloatT ] = np .full (
153+ self ._dx : NDArray [F ] = np .full (
154154 (3 ,), - 1 , dtype = self .coordinates [self .axes [0 ]].dtype
155155 )
156156 for i , ax in enumerate (self .axes ):
157157 if self .size == 1 or is_uniformly_spaced (self .coordinates [ax ]):
158158 # got a constant step in this direction, store it
159159 self ._dx [i ] = self .coordinates [ax ][1 ] - self .coordinates [ax ][0 ]
160160
161- _validators : list [type [Validator [Grid [FloatT ]]]] = [
161+ _validators : list [type [Validator [Grid [D , F ]]]] = [
162162 GeometryValidator ,
163163 BasicCoordinatesValidator ,
164164 GridCoordinatesValidator ,
@@ -186,17 +186,17 @@ def __repr__(self) -> str:
186186 )
187187
188188 @property
189- def cell_edges (self ) -> FieldMap [FloatT ]:
189+ def cell_edges (self ) -> FieldMap [D1 , F ]:
190190 r"""An alias for self.coordinates."""
191191 return self .coordinates
192192
193193 @cached_property
194- def cell_centers (self ) -> FieldMap [FloatT ]:
194+ def cell_centers (self ) -> FieldMap [D1 , F ]:
195195 r"""The positions of cell centers in each direction."""
196196 return {ax : 0.5 * (arr [1 :] + arr [:- 1 ]) for ax , arr in self .coordinates .items ()} # type: ignore [misc]
197197
198198 @cached_property
199- def cell_widths (self ) -> FieldMap [FloatT ]:
199+ def cell_widths (self ) -> FieldMap [D1 , F ]:
200200 r"""The width of cells, expressed as the difference between consecutive left edges."""
201201 return {ax : np .diff (arr ) for ax , arr in self .coordinates .items ()}
202202
@@ -220,7 +220,7 @@ def ndim(self) -> int:
220220 return len (self .axes )
221221
222222 @property
223- def cell_volumes (self ) -> NDArray [FloatT ]:
223+ def cell_volumes (self ) -> NDArray [F ]:
224224 r"""
225225 The generalized ND-volume of grid cells.
226226
@@ -241,7 +241,7 @@ def cell_volumes(self) -> NDArray[FloatT]:
241241@final
242242class ParticleSetCoordinatesValidator :
243243 @classmethod
244- def collect_exceptions (cls , data : ParticleSet [FloatT ]) -> list [Exception ]:
244+ def collect_exceptions (cls , data : ParticleSet [D , F ]) -> list [Exception ]:
245245 return FieldMapsValidatorHelper .collect_exceptions (
246246 data .coordinates ,
247247 require_shape_equality = True ,
@@ -250,13 +250,13 @@ def collect_exceptions(cls, data: ParticleSet[FloatT]) -> list[Exception]:
250250
251251
252252@final
253- class ParticleSet (Generic [FloatT ]):
253+ class ParticleSet (Generic [D , F ]):
254254 def __init__ (
255255 self ,
256256 * ,
257257 geometry : Geometry ,
258- coordinates : FieldMap [FloatT ],
259- fields : FieldMap [FloatT ] | None = None ,
258+ coordinates : FieldMap [D1 , F ],
259+ fields : FieldMap [D , F ] | None = None ,
260260 ) -> None :
261261 r"""
262262 Define a ParticleSet from point positions and data fields.
@@ -271,17 +271,17 @@ def __init__(
271271 fields (keyword-only, optional): gpgi.typing.FieldMap
272272 """
273273 self .geometry = geometry
274- self .coordinates : FieldMap [FloatT ] = coordinates
274+ self .coordinates : FieldMap [D1 , F ] = coordinates
275275
276276 if fields is None :
277277 fields = {}
278- self .fields : FieldMap [FloatT ] = fields
278+ self .fields : FieldMap [D , F ] = fields
279279
280280 self .axes = tuple (self .coordinates .keys ())
281281 self ._validate ()
282- self .dtype : np .dtype [FloatT ] = self .coordinates [self .axes [0 ]].dtype
282+ self .dtype : np .dtype [F ] = self .coordinates [self .axes [0 ]].dtype
283283
284- _validators : list [type [Validator [ParticleSet [FloatT ]]]] = [
284+ _validators : list [type [Validator [ParticleSet [D , F ]]]] = [
285285 GeometryValidator ,
286286 BasicCoordinatesValidator ,
287287 ParticleSetCoordinatesValidator ,
@@ -319,13 +319,13 @@ def ndim(self) -> int:
319319
320320
321321@final
322- class Dataset (Generic [FloatT ]):
322+ class Dataset (Generic [D , F ]):
323323 def __init__ (
324324 self ,
325325 * ,
326326 geometry : Geometry = Geometry .CARTESIAN ,
327- grid : Grid [FloatT ],
328- particles : ParticleSet [FloatT ] | None = None ,
327+ grid : Grid [D , F ],
328+ particles : ParticleSet [D , F ] | None = None ,
329329 metadata : dict [str , Any ] | None = None ,
330330 ) -> None :
331331 r"""
@@ -356,8 +356,8 @@ def __init__(
356356 coordinates = {ax : np .array ([], dtype = grid .dtype ) for ax in grid .axes },
357357 )
358358
359- self .grid : Grid [FloatT ] = grid
360- self .particles : ParticleSet [FloatT ] = particles
359+ self .grid : Grid [D , F ] = grid
360+ self .particles : ParticleSet [D , F ] = particles
361361
362362 self .boundary_recipes = BoundaryRegistry ()
363363 self .axes = self .grid .axes
@@ -401,23 +401,19 @@ def _validate(self) -> None:
401401
402402 def _get_padded_cell_edges (
403403 self ,
404- ) -> tuple [NDArray [FloatT ], NDArray [FloatT ], NDArray [FloatT ]]:
404+ ) -> tuple [NDArray [F ], NDArray [F ], NDArray [F ]]:
405405 edges = iter (self .grid .cell_edges .values ())
406406
407- def pad (a : NDArray [FloatT ]) -> NDArray [FloatT ]:
407+ def pad (a : NDArray [F ]) -> NDArray [F ]:
408408 dx = a [1 ] - a [0 ]
409409 return np .concatenate ([[a [0 ] - dx ], a , [a [- 1 ] + dx ]])
410410
411411 x1 = next (edges )
412412 cell_edges_x1 = pad (x1 )
413413 DTYPE = cell_edges_x1 .dtype
414414
415- cell_edges_x2 : np .ndarray [tuple [int , ...], np .dtype [FloatT ]] = np .empty (
416- 0 , DTYPE
417- )
418- cell_edges_x3 : np .ndarray [tuple [int , ...], np .dtype [FloatT ]] = np .empty (
419- 0 , DTYPE
420- )
415+ cell_edges_x2 : np .ndarray [tuple [int , ...], np .dtype [F ]] = np .empty (0 , DTYPE )
416+ cell_edges_x3 : np .ndarray [tuple [int , ...], np .dtype [F ]] = np .empty (0 , DTYPE )
421417 if self .grid .ndim >= 2 :
422418 cell_edges_x2 = pad (next (edges ))
423419 if self .grid .ndim == 3 :
@@ -427,7 +423,7 @@ def pad(a: NDArray[FloatT]) -> NDArray[FloatT]:
427423
428424 def _get_3D_particle_coordinates (
429425 self ,
430- ) -> tuple [NDArray [FloatT ], NDArray [FloatT ], NDArray [FloatT ]]:
426+ ) -> tuple [NDArray [F ], NDArray [F ], NDArray [F ]]:
431427 particle_coords = iter (self .particles .coordinates .values ())
432428 particles_x1 = next (particle_coords )
433429 DTYPE = particles_x1 .dtype
@@ -586,7 +582,7 @@ def deposit(
586582 weight_field : Name | None = None ,
587583 weight_field_boundaries : dict [Name , tuple [Name , Name ]] | None = None ,
588584 lock : Literal ["per-instance" ] | None | LockType = "per-instance" ,
589- ) -> NDArray [FloatT ]:
585+ ) -> NDArray [F ]:
590586 r"""
591587 Perform particle deposition and return the result as a grid field.
592588
@@ -810,9 +806,9 @@ def _sanitize_boundaries(self, boundaries: dict[Name, tuple[Name, Name]]) -> Non
810806
811807 def _apply_boundary_conditions (
812808 self ,
813- array : NDArray [FloatT ],
809+ array : NDArray [F ],
814810 boundaries : dict [Name , tuple [Name , Name ]],
815- weight_array : NDArray [FloatT ] | None ,
811+ weight_array : NDArray [F ] | None ,
816812 ) -> None :
817813 axes = list (self .grid .axes )
818814 for ax , bv in boundaries .items ():
0 commit comments