3535from meshmode .array_context import (PyOpenCLArrayContext ,
3636 PytatoPyOpenCLArrayContext )
3737from arraycontext import (
38- freeze , thaw ,
3938 ArrayContainer ,
4039 map_array_container ,
4140 with_container_arithmetic ,
5756# {{{ discretization
5857
5958def parametrization_derivative (actx , discr ):
60- thawed_nodes = thaw (discr .nodes (), actx )
59+ thawed_nodes = actx . thaw (discr .nodes ())
6160
6261 from meshmode .discretization import num_reference_derivative
6362 result = np .zeros ((discr .ambient_dim , discr .dim ), dtype = object )
@@ -175,17 +174,17 @@ def get_discr(self, where):
175174
176175 @memoize_method
177176 def parametrization_derivative (self ):
178- return freeze (
177+ return self . _setup_actx . freeze (
179178 parametrization_derivative (self ._setup_actx , self .volume_discr ))
180179
181180 @memoize_method
182181 def vol_jacobian (self ):
183- [a , b ], [c , d ] = thaw (self .parametrization_derivative (), self . _setup_actx )
184- return freeze (a * d - b * c )
182+ [a , b ], [c , d ] = self . _setup_actx . thaw (self .parametrization_derivative ())
183+ return self . _setup_actx . freeze (a * d - b * c )
185184
186185 @memoize_method
187186 def inverse_parametrization_derivative (self ):
188- [a , b ], [c , d ] = thaw (self .parametrization_derivative (), self . _setup_actx )
187+ [a , b ], [c , d ] = self . _setup_actx . thaw (self .parametrization_derivative ())
189188
190189 result = np .zeros ((2 , 2 ), dtype = object )
191190 det = a * d - b * c
@@ -194,13 +193,13 @@ def inverse_parametrization_derivative(self):
194193 result [1 , 0 ] = - c / det
195194 result [1 , 1 ] = a / det
196195
197- return freeze (result )
196+ return self . _setup_actx . freeze (result )
198197
199198 def zeros (self , actx ):
200199 return self .volume_discr .zeros (actx )
201200
202201 def grad (self , vec ):
203- ipder = thaw (self .inverse_parametrization_derivative (), vec . array_context )
202+ ipder = vec . array_context . thaw (self .inverse_parametrization_derivative ())
204203
205204 from meshmode .discretization import num_reference_derivative
206205 dref = [
@@ -222,15 +221,15 @@ def normal(self, where):
222221 ((a ,), (b ,)) = parametrization_derivative (self ._setup_actx , bdry_discr )
223222
224223 nrm = 1 / (a ** 2 + b ** 2 )** 0.5
225- return freeze (flat_obj_array (b * nrm , - a * nrm ))
224+ return self . _setup_actx . freeze (flat_obj_array (b * nrm , - a * nrm ))
226225
227226 @memoize_method
228227 def face_jacobian (self , where ):
229228 bdry_discr = self .get_discr (where )
230229
231230 ((a ,), (b ,)) = parametrization_derivative (self ._setup_actx , bdry_discr )
232231
233- return freeze ((a ** 2 + b ** 2 )** 0.5 )
232+ return self . _setup_actx . freeze ((a ** 2 + b ** 2 )** 0.5 )
234233
235234 @memoize_method
236235 def get_inverse_mass_matrix (self , grp , dtype ):
@@ -261,7 +260,7 @@ def inverse_mass(self, vec):
261260 tagged = (FirstAxisIsElementsTag (),)
262261 ) for grp , vec_i in zip (discr .groups , vec )
263262 )
264- ) / thaw (self .vol_jacobian (), actx )
263+ ) / actx . thaw (self .vol_jacobian ())
265264
266265 @memoize_method
267266 def get_local_face_mass_matrix (self , afgrp , volgrp , dtype ):
@@ -300,7 +299,7 @@ def face_mass(self, vec):
300299 all_faces_discr = all_faces_conn .to_discr
301300 vol_discr = all_faces_conn .from_discr
302301
303- fj = thaw (self .face_jacobian ("all_faces" ), vec . array_context )
302+ fj = vec . array_context . thaw (self .face_jacobian ("all_faces" ))
304303 vec = vec * fj
305304
306305 assert len (all_faces_discr .groups ) == len (vol_discr .groups )
@@ -367,7 +366,7 @@ def wave_flux(actx, discr, c, q_tpair):
367366 u = q_tpair .u
368367 v = q_tpair .v
369368
370- normal = thaw (discr .normal (q_tpair .where ), actx )
369+ normal = actx . thaw (discr .normal (q_tpair .where ))
371370
372371 flux_weak = WaveState (
373372 u = np .dot (v .avg , normal ),
@@ -422,7 +421,7 @@ def bump(actx, discr, t=0):
422421 source_width = 0.05
423422 source_omega = 3
424423
425- nodes = thaw (discr .volume_discr .nodes (), actx )
424+ nodes = actx . thaw (discr .volume_discr .nodes ())
426425 center_dist = flat_obj_array ([
427426 nodes [0 ] - source_center [0 ],
428427 nodes [1 ] - source_center [1 ],
@@ -492,8 +491,8 @@ def rhs(t, q):
492491 compiled_rhs = actx_rhs .compile (rhs )
493492
494493 def rhs_wrapper (t , q ):
495- r = compiled_rhs (t , thaw (freeze (q , actx_outer ), actx_rhs ))
496- return thaw (freeze (r , actx_rhs ), actx_outer )
494+ r = compiled_rhs (t , actx_rhs . thaw (actx_outer . freeze (q ) ))
495+ return actx_outer . thaw (actx_rhs . freeze (r ) )
497496
498497 t = np .float64 (0 )
499498 t_final = 3
0 commit comments