11using DiffEqBase, ClimaTimeSteppers, LinearAlgebra, StaticArrays
22using ClimaCore
3+ import ClimaCore. Device as Device
34import ClimaCore. Domains as Domains
45import ClimaCore. Geometry as Geometry
56import ClimaCore. Meshes as Meshes
@@ -21,16 +22,21 @@ u(t) = u_0 e^{αt}
2122
2223This is an in-place variant of the one from DiffEqProblemLibrary.jl.
2324"""
24- function linear_prob ()
25+ function linear_prob (:: Type{ArrayType} = Array) where {ArrayType}
2526 ODEProblem (
2627 IncrementingODEFunction {true} ((du, u, p, t, α = true , β = false ) -> (du .= α .* p .* u .+ β .* du)),
27- [1 / 2 ],
28+ ArrayType ( [1 / 2 ]) ,
2829 (0.0 , 1.0 ),
2930 1.01 ,
3031 )
3132end
32- function linear_prob_fe ()
33- ODEProblem (ForwardEulerODEFunction ((un, u, p, t, dt) -> (un .= u .+ dt .* p .* u)), [1.0 ], (0.0 , 1.0 ), - 0.2 )
33+ function linear_prob_fe (:: Type{ArrayType} = Array) where {ArrayType}
34+ ODEProblem (
35+ ForwardEulerODEFunction ((un, u, p, t, dt) -> (un .= u .+ dt .* p .* u)),
36+ ArrayType ([1.0 ]),
37+ (0.0 , 1.0 ),
38+ - 0.2 ,
39+ )
3440end
3541
3642function linear_prob_wfactt ()
@@ -96,27 +102,29 @@ with initial condition ``u_0=[0,1]``, parameter ``α=2``, and solution
96102u(t) = [cos(αt) sin(αt); -sin(αt) cos(αt) ] u_0
97103```
98104"""
99- function sincos_prob ()
105+ function sincos_prob (:: Type{ArrayType} = Array) where {ArrayType}
100106 ODEProblem (
101107 IncrementingODEFunction {true} ((du, u, p, t, α = true , β = false ) -> (du[1 ] = α * p * u[2 ] + β * du[1 ];
102108 du[2 ] = - α * p * u[1 ] + β * du[2 ])),
103- [0.0 , 1.0 ],
109+ ArrayType ( [0.0 , 1.0 ]) ,
104110 (0.0 , 1.0 ),
105111 2.0 ,
106112 )
107113end
108- function sincos_prob_fe ()
114+ function sincos_prob_fe (:: Type{ArrayType} = Array) where {ArrayType}
109115 ODEProblem (
110116 ForwardEulerODEFunction ((un, u, p, t, dt) -> (un[1 ] = u[1 ] + dt * p * u[2 ]; un[2 ] = u[2 ] - dt * p * u[1 ])),
111- [0.0 , 1.0 ],
117+ ArrayType ( [0.0 , 1.0 ]) ,
112118 (0.0 , 1.0 ),
113119 2.0 ,
114120 )
115121end
116122
117123function sincos_sol (u0, p, t)
118124 s, c = sincos (p * t)
119- [c s; - s c] * u0
125+ SC = similar (u0, (2 , 2 ))
126+ copyto! (SC, [c s; - s c])
127+ return SC * u0
120128end
121129
122130"""
428436
4294372D diffusion test problem. See [`2D diffusion problem`](@ref) for more details.
430438"""
431- function climacore_2Dheat_test_cts (:: Type{FT} ) where {FT}
439+ function climacore_2Dheat_test_cts (:: Type{FT} ; print_arr_type = false ) where {FT}
432440 dss_tendency = true
433441
442+ device = Device. device ()
443+ context = ClimaComms. SingletonCommsContext (device)
444+
445+ if print_arr_type
446+ @info " Array type: $(Device. device_array_type (device)) "
447+ end
448+
434449 n_elem_x = 2
435450 n_elem_y = 2
436451 n_poly = 2
@@ -445,13 +460,17 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT}
445460 Domains. IntervalDomain (Geometry. YPoint (FT (0 )), Geometry. YPoint (FT (1 )), periodic = true ),
446461 )
447462 mesh = Meshes. RectilinearMesh (domain, n_elem_x, n_elem_y)
448- topology = Topologies. Topology2D (mesh)
463+ topology = Topologies. Topology2D (context, mesh)
449464 quadrature = Spaces. Quadratures. GLL {n_poly + 1} ()
450465 space = Spaces. SpectralElementSpace2D (topology, quadrature)
451466 (; x, y) = Fields. coordinate_field (space)
452467
453468 λ = (2 * FT (π))^ 2 * (n_x^ 2 + n_y^ 2 )
454- φ_sin_sin = @. sin (2 * FT (π) * n_x * x) * sin (2 * FT (π) * n_y * y)
469+
470+ # Revert once https://github.com/CliMA/ClimaCore.jl/issues/1097
471+ # is fixed
472+ # φ_sin_sin = @. sin(2 * FT(π) * n_x * x) * sin(2 * FT(π) * n_y * y)
473+ φ_sin_sin = @. sin (2 * π * n_x * x) * sin (2 * π * n_y * y)
455474
456475 init_state = Fields. FieldVector (; u = φ_sin_sin)
457476
0 commit comments