Skip to content

Commit a1a9145

Browse files
Add simple GPU test to test suite
1 parent a65fad6 commit a1a9145

File tree

3 files changed

+79
-32
lines changed

3 files changed

+79
-32
lines changed

.buildkite/pipeline.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,17 @@ steps:
117117
queue: central
118118
slurm_ntasks: 1
119119

120-
- label: "GPU"
120+
# - label: "GPU"
121+
# command:
122+
# - "julia --project=test test/runtests.jl CuArray"
123+
# agents:
124+
# slurm_gres: "gpu:1"
125+
# queue: central
126+
# slurm_ntasks: 1
127+
128+
- label: "Simple GPU"
121129
command:
122-
- "julia --project=test test/runtests.jl CuArray"
130+
- "julia --project=test test/simple_gpu.jl"
123131
agents:
124132
slurm_gres: "gpu:1"
125133
queue: central

test/problems.jl

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
using DiffEqBase, ClimaTimeSteppers, LinearAlgebra, StaticArrays
22
using ClimaCore
3+
import ClimaCore.Device as Device
4+
import ClimaCore.Domains as Domains
5+
import ClimaCore.Geometry as Geometry
6+
import ClimaCore.Meshes as Meshes
7+
import ClimaCore.Topologies as Topologies
8+
import ClimaCore.Spaces as Spaces
9+
import ClimaCore.Fields as Fields
10+
import ClimaCore.Operators as Operators
311

412
"""
513
Single variable linear ODE
@@ -421,9 +429,16 @@ end
421429
422430
2D diffusion test problem. See [`2D diffusion problem`](@ref) for more details.
423431
"""
424-
function climacore_2Dheat_test_cts(::Type{FT}) where {FT}
432+
function climacore_2Dheat_test_cts(::Type{FT}; print_arr_type = false) where {FT}
425433
dss_tendency = true
426434

435+
device = Device.device()
436+
context = ClimaComms.SingletonCommsContext(device)
437+
438+
if print_arr_type
439+
@info "Array type: $(Device.device_array_type(device))"
440+
end
441+
427442
n_elem_x = 2
428443
n_elem_y = 2
429444
n_poly = 2
@@ -433,38 +448,42 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT}
433448
Δλ = FT(1) # denoted by Δλ̂ above
434449
t_end = FT(0.05) # denoted by t̂ above
435450

436-
domain = ClimaCore.Domains.RectangleDomain(
437-
ClimaCore.Domains.IntervalDomain(
438-
ClimaCore.Geometry.XPoint(FT(0)),
439-
ClimaCore.Geometry.XPoint(FT(1)),
451+
domain = Domains.RectangleDomain(
452+
Domains.IntervalDomain(
453+
Geometry.XPoint(FT(0)),
454+
Geometry.XPoint(FT(1)),
440455
periodic = true,
441456
),
442-
ClimaCore.Domains.IntervalDomain(
443-
ClimaCore.Geometry.YPoint(FT(0)),
444-
ClimaCore.Geometry.YPoint(FT(1)),
457+
Domains.IntervalDomain(
458+
Geometry.YPoint(FT(0)),
459+
Geometry.YPoint(FT(1)),
445460
periodic = true,
446461
),
447462
)
448-
mesh = ClimaCore.Meshes.RectilinearMesh(domain, n_elem_x, n_elem_y)
449-
topology = ClimaCore.Topologies.Topology2D(mesh)
450-
quadrature = ClimaCore.Spaces.Quadratures.GLL{n_poly + 1}()
451-
space = ClimaCore.Spaces.SpectralElementSpace2D(topology, quadrature)
452-
(; x, y) = ClimaCore.Fields.coordinate_field(space)
463+
mesh = Meshes.RectilinearMesh(domain, n_elem_x, n_elem_y)
464+
topology = Topologies.Topology2D(context, mesh)
465+
quadrature = Spaces.Quadratures.GLL{n_poly + 1}()
466+
space = Spaces.SpectralElementSpace2D(topology, quadrature)
467+
(; x, y) = Fields.coordinate_field(space)
453468

454469
λ = (2 * FT(π))^2 * (n_x^2 + n_y^2)
455-
φ_sin_sin = @. sin(2 * FT(π) * n_x * x) * sin(2 * FT(π) * n_y * y)
456470

457-
init_state = ClimaCore.Fields.FieldVector(; u = φ_sin_sin)
471+
# Revert once https://github.com/CliMA/ClimaCore.jl/issues/1097
472+
# is fixed
473+
# φ_sin_sin = @. sin(2 * FT(π) * n_x * x) * sin(2 * FT(π) * n_y * y)
474+
φ_sin_sin = @. sin(2 * π * n_x * x) * sin(2 * π * n_y * y)
475+
476+
init_state = Fields.FieldVector(; u = φ_sin_sin)
458477

459-
wdiv = ClimaCore.Operators.WeakDivergence()
460-
grad = ClimaCore.Operators.Gradient()
478+
wdiv = Operators.WeakDivergence()
479+
grad = Operators.Gradient()
461480
function T_exp!(tendency, state, _, t)
462481
@. tendency.u = wdiv(grad(state.u)) + f_0 * exp(-+ Δλ) * t) * φ_sin_sin
463-
dss_tendency && ClimaCore.Spaces.weighted_dss!(tendency.u)
482+
dss_tendency && Spaces.weighted_dss!(tendency.u)
464483
end
465484

466485
function dss!(state, _, t)
467-
dss_tendency || ClimaCore.Spaces.weighted_dss!(state.u)
486+
dss_tendency || Spaces.weighted_dss!(state.u)
468487
end
469488

470489
function analytic_sol(t)
@@ -492,25 +511,25 @@ function climacore_1Dheat_test_cts(::Type{FT}) where {FT}
492511
Δλ = FT(1) # denoted by Δλ̂ above
493512
t_end = FT(0.1) # denoted by t̂ above
494513

495-
domain = ClimaCore.Domains.IntervalDomain(
496-
ClimaCore.Geometry.ZPoint(FT(0)),
497-
ClimaCore.Geometry.ZPoint(FT(1)),
514+
domain = Domains.IntervalDomain(
515+
Geometry.ZPoint(FT(0)),
516+
Geometry.ZPoint(FT(1)),
498517
boundary_names = (:bottom, :top),
499518
)
500-
mesh = ClimaCore.Meshes.IntervalMesh(domain, nelems = n_elem_z)
501-
space = ClimaCore.Spaces.FaceFiniteDifferenceSpace(mesh)
502-
(; z) = ClimaCore.Fields.coordinate_field(space)
519+
mesh = Meshes.IntervalMesh(domain, nelems = n_elem_z)
520+
space = Spaces.FaceFiniteDifferenceSpace(mesh)
521+
(; z) = Fields.coordinate_field(space)
503522

504523
λ = (2 * FT(π) * n_z)^2
505524
φ_sin = @. sin(2 * FT(π) * n_z * z)
506525

507-
init_state = ClimaCore.Fields.FieldVector(; u = φ_sin)
526+
init_state = Fields.FieldVector(; u = φ_sin)
508527

509-
div = ClimaCore.Operators.DivergenceC2F(;
510-
bottom = ClimaCore.Operators.SetDivergence(FT(0)),
511-
top = ClimaCore.Operators.SetDivergence(FT(0)),
528+
div = Operators.DivergenceC2F(;
529+
bottom = Operators.SetDivergence(FT(0)),
530+
top = Operators.SetDivergence(FT(0)),
512531
)
513-
grad = ClimaCore.Operators.GradientF2C()
532+
grad = Operators.GradientF2C()
514533
function T_exp!(tendency, state, _, t)
515534
@. tendency.u = div(grad(state.u)) + f_0 * exp(-+ Δλ) * t) * φ_sin
516535
end

test/simple_gpu.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using ClimaTimeSteppers
2+
import ClimaTimeSteppers as CTS
3+
import OrdinaryDiffEq as ODE
4+
5+
include(joinpath(@__DIR__, "problems.jl"))
6+
7+
function main(::Type{FT}) where {FT}
8+
alg_name = ARS343()
9+
test_case = climacore_2Dheat_test_cts(FT; print_arr_type = true)
10+
prob = test_case.split_prob
11+
alg = CTS.IMEXAlgorithm(alg_name, NewtonsMethod(; max_iters = 2))
12+
integrator = ODE.init(prob, alg; dt = FT(0.01))
13+
sol = ODE.solve!(integrator)
14+
@info "Done!"
15+
return integrator
16+
end
17+
18+
integrator = main(Float64)
19+
nothing
20+

0 commit comments

Comments
 (0)