@@ -96,52 +96,59 @@ def make_field_plotter_from_bbox(
9696
9797class FieldPlotter :
9898 """
99+ .. autoattribute:: dimensions
100+ .. autoattribute:: npoints
101+ .. autoattribute:: points
102+
99103 .. automethod:: set_matplotlib_limits
100104 .. automethod:: show_scalar_in_matplotlib
101105 .. automethod:: show_scalar_in_mayavi
102106 .. automethod:: write_vtk_file
103107 """
104108
105109 dimensions : int
110+ npoints : int
111+ points : onp .Array2D [np .floating [Any ]]
112+
106113 a : onp .Array1D [np .floating [Any ]]
107114 b : onp .Array1D [np .floating [Any ]]
108-
109- nd_points : onp .Array2D [np .floating [Any ]]
110- points : onp .Array2D [np .floating [Any ]]
111- npoints : int
115+ nd_points : onp .ArrayND [np .floating [Any ]]
112116
113117 def __init__ (self ,
114118 center : onp .ToArray1D [np .floating [Any ]],
115119 extent : float | onp .Array1D [np .floating [Any ]] = 1 ,
116- npoints : int | tuple [int , ...] = 1000 ) -> None :
120+ npoints : int | tuple [int , ...] = 1000 ,
121+ points : onp .ArrayND [np .floating [Any ]] | None = None ) -> None :
117122 center = np .asarray (center )
118123 dim , = cast ("tuple[int]" , center .shape )
119124
120125 self .dimensions = dim
121126 self .a = a = center - 0.5 * extent
122127 self .b = b = center + 0.5 * extent
123128
124- from numbers import Number
125- if isinstance (npoints , (int , Number )):
126- npoints = dim * (npoints ,)
129+ if points is None :
130+ from numbers import Number
131+ if isinstance (npoints , (int , Number )):
132+ npoints = dim * (npoints ,)
133+ else :
134+ if len (npoints ) != dim :
135+ raise ValueError ("length of npoints must match dimension" )
136+
137+ for i in range (dim ):
138+ if npoints [i ] == 1 :
139+ a [i ] = center [i ]
140+
141+ mgrid_index = tuple (
142+ slice (a [i ], b [i ], 1j * npoints [i ])
143+ for i in range (dim ))
144+ mgrid = np .mgrid [mgrid_index ]
127145 else :
128- if len (npoints ) != dim :
129- raise ValueError ("length of npoints must match dimension" )
130-
131- for i in range (dim ):
132- if npoints [i ] == 1 :
133- a [i ] = center [i ]
134-
135- mgrid_index = tuple (
136- slice (a [i ], b [i ], 1j * npoints [i ])
137- for i in range (dim ))
138- mgrid = np .mgrid [mgrid_index ]
146+ mgrid = points
139147
140148 # (axis, point x idx, point y idx, ...)
141149 self .nd_points = mgrid
142-
143- self .points = self .nd_points .reshape (dim , - 1 ).copy ()
144- self .npoints = np .prod (npoints )
150+ self .points = mgrid .reshape (dim , - 1 ).copy ()
151+ self .npoints = mgrid .size
145152
146153 def _get_nontrivial_dims (self ) -> onp .Array1D [np .bool_ ]:
147154 return np .array (self .nd_points .shape [1 :]) != 1
0 commit comments