11using DiffEqBase, ClimaTimeSteppers, LinearAlgebra, StaticArrays
22using ClimaCore
3+ import ClimaCore. Device as Device
34import ClimaCore. Domains as Domains
45import ClimaCore. Geometry as Geometry
56import ClimaCore. Meshes as Meshes
@@ -21,16 +22,16 @@ u(t) = u_0 e^{αt}
2122
2223This is an in-place variant of the one from DiffEqProblemLibrary.jl.
2324"""
24- function linear_prob ()
25+ function linear_prob (:: Type{ArrayType} = Array) where {ArrayType}
2526 ODEProblem (
2627 IncrementingODEFunction {true} ((du, u, p, t, α = true , β = false ) -> (du .= α .* p .* u .+ β .* du)),
27- [1 / 2 ],
28+ ArrayType ( [1 / 2 ]) ,
2829 (0.0 , 1.0 ),
2930 1.01 ,
3031 )
3132end
32- function linear_prob_fe ()
33- ODEProblem (ForwardEulerODEFunction ((un, u, p, t, dt) -> (un .= u .+ dt .* p .* u)), [1.0 ], (0.0 , 1.0 ), - 0.2 )
33+ function linear_prob_fe (:: Type{ArrayType} = Array) where {ArrayType}
34+ ODEProblem (ForwardEulerODEFunction ((un, u, p, t, dt) -> (un .= u .+ dt .* p .* u)), ArrayType ( [1.0 ]) , (0.0 , 1.0 ), - 0.2 )
3435end
3536
3637function linear_prob_wfactt ()
@@ -96,27 +97,29 @@ with initial condition ``u_0=[0,1]``, parameter ``α=2``, and solution
9697u(t) = [cos(αt) sin(αt); -sin(αt) cos(αt) ] u_0
9798```
9899"""
99- function sincos_prob ()
100+ function sincos_prob (:: Type{ArrayType} = Array) where {ArrayType}
100101 ODEProblem (
101102 IncrementingODEFunction {true} ((du, u, p, t, α = true , β = false ) -> (du[1 ] = α * p * u[2 ] + β * du[1 ];
102103 du[2 ] = - α * p * u[1 ] + β * du[2 ])),
103- [0.0 , 1.0 ],
104+ ArrayType ( [0.0 , 1.0 ]) ,
104105 (0.0 , 1.0 ),
105106 2.0 ,
106107 )
107108end
108- function sincos_prob_fe ()
109+ function sincos_prob_fe (:: Type{ArrayType} = Array) where {ArrayType}
109110 ODEProblem (
110111 ForwardEulerODEFunction ((un, u, p, t, dt) -> (un[1 ] = u[1 ] + dt * p * u[2 ]; un[2 ] = u[2 ] - dt * p * u[1 ])),
111- [0.0 , 1.0 ],
112+ ArrayType ( [0.0 , 1.0 ]) ,
112113 (0.0 , 1.0 ),
113114 2.0 ,
114115 )
115116end
116117
117118function sincos_sol (u0, p, t)
118119 s, c = sincos (p * t)
119- [c s; - s c] * u0
120+ SC = similar (u0, (2 ,2 ))
121+ copyto! (SC, [c s; - s c])
122+ return SC * u0
120123end
121124
122125"""
428431
4294322D diffusion test problem. See [`2D diffusion problem`](@ref) for more details.
430433"""
431- function climacore_2Dheat_test_cts (:: Type{FT} ) where {FT}
434+ function climacore_2Dheat_test_cts (:: Type{FT} ; print_arr_type = false ) where {FT}
432435 dss_tendency = true
433436
437+ device = Device. device ()
438+ context = ClimaComms. SingletonCommsContext (device)
439+
440+ if print_arr_type
441+ @info " Array type: $(Device. device_array_type (device)) "
442+ end
443+
434444 n_elem_x = 2
435445 n_elem_y = 2
436446 n_poly = 2
@@ -445,13 +455,17 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT}
445455 Domains. IntervalDomain (Geometry. YPoint (FT (0 )), Geometry. YPoint (FT (1 )), periodic = true ),
446456 )
447457 mesh = Meshes. RectilinearMesh (domain, n_elem_x, n_elem_y)
448- topology = Topologies. Topology2D (mesh)
458+ topology = Topologies. Topology2D (context, mesh)
449459 quadrature = Spaces. Quadratures. GLL {n_poly + 1} ()
450460 space = Spaces. SpectralElementSpace2D (topology, quadrature)
451461 (; x, y) = Fields. coordinate_field (space)
452462
453463 λ = (2 * FT (π))^ 2 * (n_x^ 2 + n_y^ 2 )
454- φ_sin_sin = @. sin (2 * FT (π) * n_x * x) * sin (2 * FT (π) * n_y * y)
464+
465+ # Revert once https://github.com/CliMA/ClimaCore.jl/issues/1097
466+ # is fixed
467+ # φ_sin_sin = @. sin(2 * FT(π) * n_x * x) * sin(2 * FT(π) * n_y * y)
468+ φ_sin_sin = @. sin (2 * π * n_x * x) * sin (2 * π * n_y * y)
455469
456470 init_state = Fields. FieldVector (; u = φ_sin_sin)
457471
0 commit comments