Skip to content

Commit 15990ea

Browse files
Add simple GPU test to test suite
1 parent 6e3bd49 commit 15990ea

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

.buildkite/pipeline.yml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,17 @@ steps:
117117
queue: central
118118
slurm_ntasks: 1
119119

120-
- label: "GPU"
120+
# - label: "GPU"
121+
# command:
122+
# - "julia --project=test test/runtests.jl CuArray"
123+
# agents:
124+
# slurm_gres: "gpu:1"
125+
# queue: central
126+
# slurm_ntasks: 1
127+
128+
- label: "Simple GPU"
121129
command:
122-
- "julia --project=test test/runtests.jl CuArray"
130+
- "julia --project=test test/simple_gpu.jl"
123131
agents:
124132
slurm_gres: "gpu:1"
125133
queue: central

test/problems.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using DiffEqBase, ClimaTimeSteppers, LinearAlgebra, StaticArrays
22
using ClimaCore
3+
import ClimaCore.Device as Device
34
import ClimaCore.Domains as Domains
45
import ClimaCore.Geometry as Geometry
56
import ClimaCore.Meshes as Meshes
@@ -428,9 +429,16 @@ end
428429
429430
2D diffusion test problem. See [`2D diffusion problem`](@ref) for more details.
430431
"""
431-
function climacore_2Dheat_test_cts(::Type{FT}) where {FT}
432+
function climacore_2Dheat_test_cts(::Type{FT}; print_arr_type = false) where {FT}
432433
dss_tendency = true
433434

435+
device = Device.device()
436+
context = ClimaComms.SingletonCommsContext(device)
437+
438+
if print_arr_type
439+
@info "Array type: $(Device.device_array_type(device))"
440+
end
441+
434442
n_elem_x = 2
435443
n_elem_y = 2
436444
n_poly = 2
@@ -445,13 +453,17 @@ function climacore_2Dheat_test_cts(::Type{FT}) where {FT}
445453
Domains.IntervalDomain(Geometry.YPoint(FT(0)), Geometry.YPoint(FT(1)), periodic = true),
446454
)
447455
mesh = Meshes.RectilinearMesh(domain, n_elem_x, n_elem_y)
448-
topology = Topologies.Topology2D(mesh)
456+
topology = Topologies.Topology2D(context, mesh)
449457
quadrature = Spaces.Quadratures.GLL{n_poly + 1}()
450458
space = Spaces.SpectralElementSpace2D(topology, quadrature)
451459
(; x, y) = Fields.coordinate_field(space)
452460

453461
λ = (2 * FT(π))^2 * (n_x^2 + n_y^2)
454-
φ_sin_sin = @. sin(2 * FT(π) * n_x * x) * sin(2 * FT(π) * n_y * y)
462+
463+
# Revert once https://github.com/CliMA/ClimaCore.jl/issues/1097
464+
# is fixed
465+
# φ_sin_sin = @. sin(2 * FT(π) * n_x * x) * sin(2 * FT(π) * n_y * y)
466+
φ_sin_sin = @. sin(2 * π * n_x * x) * sin(2 * π * n_y * y)
455467

456468
init_state = Fields.FieldVector(; u = φ_sin_sin)
457469

test/simple_gpu.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using ClimaTimeSteppers
2+
import ClimaTimeSteppers as CTS
3+
import OrdinaryDiffEq as ODE
4+
5+
include(joinpath(@__DIR__, "problems.jl"))
6+
7+
function main(::Type{FT}) where {FT}
8+
alg_name = ARS343()
9+
test_case = climacore_2Dheat_test_cts(FT; print_arr_type = true)
10+
prob = test_case.split_prob
11+
alg = CTS.IMEXAlgorithm(alg_name, NewtonsMethod(; max_iters = 2))
12+
integrator = ODE.init(prob, alg; dt = FT(0.01))
13+
sol = ODE.solve!(integrator)
14+
@info "Done!"
15+
return integrator
16+
end
17+
18+
integrator = main(Float64)
19+
nothing

0 commit comments

Comments
 (0)