Skip to content

Commit 6ab868b

Browse files
committed
update GPU exercise
1 parent 14f4cc0 commit 6ab868b

File tree

2 files changed

+88
-45
lines changed

2 files changed

+88
-45
lines changed

exercise_03_intro_accelerated.jl

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ begin
4242
using CairoMakie
4343
end
4444

45+
# ╔═╡ 1b9844e2-2e2b-47c2-8a5e-3ff3dbfacc16
46+
using OrdinaryDiffEqTsit5
47+
4548
# ╔═╡ 3e5c3c97-4401-41d4-a701-d9b24f9acdc6
4649
PlutoUI.TableOfContents(; depth=4)
4750

@@ -214,7 +217,7 @@ begin
214217
end
215218

216219
# ╔═╡ 0ad45abb-0d9f-4e8d-b097-b0b42ba024f7
217-
dt = dx^2 * dy^2 / (2.0 * a * (dx^2 + dy^2)) # Largest stable time step
220+
t_final = 3.0
218221

219222
# ╔═╡ 83042a1e-f964-483d-b316-a486cfabd7e0
220223
N = 64
@@ -234,28 +237,32 @@ $dU = a * dt * (\frac{U[i+1, j] - 2U[i,j] + U[i-1,j]}{dx^2} + \frac{U[i, j+1] -
234237
"""
235238

236239
# ╔═╡ 9eb166fa-360e-4a5d-a2ac-9113c2f264b3
237-
@kernel function diffuse(dU, @Const(U), a, dt, dx, dy)
240+
@kernel function heat_2D_kernel!(du, @Const(u), a, dx, dy)
238241
# implement me
239242
end
240243

241244
# ╔═╡ aa2d455e-c9fc-4a7b-b50e-77709481c2a7
242-
function diffuse!(U, a, dt, dx, dy)
243-
dU = zero(U)
244-
diffuse(get_backend(U))(dU, U, a, dt, dx, dy; ndrange=(N,N))
245-
U .+= dU
245+
function heat_2D!(du, u, (a, dx, dy), t)
246+
N, M = size(u)
247+
N = N - 2
248+
M = M - 2
249+
250+
# update boundary condition (wrap around)
251+
u[0, :] .= u[N, :]
252+
u[N+1, :] .= u[1, :]
253+
u[:, 0] .= u[:, N]
254+
u[:, N+1] .= u[:, 0]
246255

247-
# update boundary condition (wrap around)
248-
U[0, :] .= U[N, :]
249-
U[N+1, :] .= U[1, :]
250-
U[:, 0] .= U[:, N]
251-
U[:, N+1] .= U[:, 0]
252-
U
256+
kernel = heat_2D_kernel!(get_backend(du))
257+
kernel(du, u, a, dx, dy; ndrange=(N,M))
258+
259+
return nothing
253260
end
254261

255262
# ╔═╡ f912ee44-15ed-469f-b417-cf7d8d87146e
256263
answer_box(hint(md"""
257264
```julia
258-
@kernel function diffuse(dU, @Const(U), a, dt, dx, dy)
265+
@kernel function heat_2D_kernel(du, @Const(u), a, dx, dy)
259266
i, j = @index(Global, NTuple)
260267
out[i, j] = a * dt * (
261268
(U[i + 1, j] - 2 * U[i, j] + U[i - 1, j]) / dx^2 +
@@ -265,24 +272,55 @@ end
265272
```
266273
"""))
267274

268-
# ╔═╡ 2a986721-f513-488d-970e-4797f0de135f
269-
let
275+
# ╔═╡ 7147bba2-78ae-49fa-ae35-2c815ee188ae
276+
begin
270277
xs = 0:(N+1)
271278
ys = 0:(N+1)
272-
domain = OffsetArray(
273-
KernelAbstractions.zeros(backend, Float32, N+2, N+2),
274-
xs, ys)
275-
# TODO: Split out into initalize function
276-
parent(domain)[16:32, 16:32] .= 5
277-
278-
fig, ax, hm = heatmap(xs, ys, Array(parent(domain)))
279+
280+
u₀ = OffsetArray(
281+
KernelAbstractions.zeros(
282+
backend, Float32, N+2, N+2)
283+
, xs, ys)
284+
parent(u₀)[16:32, 16:32] .= 5
285+
286+
heatmap(xs, ys, parent(u₀))
287+
end
288+
289+
# ╔═╡ cdf61aff-ec98-4174-9d48-c287977742cb
290+
prob = ODEProblem(heat_2D!, u₀, (0.0, t_final), (a, dx, dy))
291+
292+
# ╔═╡ 8e963ae6-06fd-47fc-a1de-e884468de234
293+
sol = solve(prob, Tsit5(), saveat=0.2);
294+
295+
# ╔═╡ 78f7884c-3523-4d96-9dfc-fbe9de36a86b
296+
let
297+
idx = Observable(1)
298+
data = @lift Array(parent(sol.u[$idx]))
299+
fig, ax, hm = heatmap(xs, ys, data)
279300

280-
Makie.Record(fig, 1:250) do i
281-
diffuse!(domain, a, dt, dx, dy) # update data
282-
autolimits!(ax) # update limits
301+
Makie.Record(fig, 1:length(sol.u), framerate=5) do i
302+
idx[] = i
283303
end
284304
end
285305

306+
# ╔═╡ 2a986721-f513-488d-970e-4797f0de135f
307+
# let
308+
# xs = 0:(N+1)
309+
# ys = 0:(N+1)
310+
# domain = OffsetArray(
311+
# KernelAbstractions.zeros(backend, Float32, N+2, N+2),
312+
# xs, ys)
313+
# # TODO: Split out into initalize function
314+
# parent(domain)[16:32, 16:32] .= 5
315+
316+
# fig, ax, hm = heatmap(xs, ys, Array(parent(domain)))
317+
318+
# Makie.Record(fig, 1:250) do i
319+
# diffuse!(domain, a, dt, dx, dy) # update data
320+
# autolimits!(ax) # update limits
321+
# end
322+
# end
323+
286324
# ╔═╡ 00000000-0000-0000-0000-000000000001
287325
PLUTO_PROJECT_TOML_CONTENTS = """
288326
[deps]
@@ -314,7 +352,7 @@ oneAPI = "~2.0.3"
314352
PLUTO_MANIFEST_TOML_CONTENTS = """
315353
# This file is machine-generated - editing it directly is not advised
316354
317-
julia_version = "1.11.5"
355+
julia_version = "1.11.6"
318356
manifest_format = "2.0"
319357
project_hash = "81e4ea397ec4528af2733db8b8dce3926ca44ef1"
320358
@@ -2314,13 +2352,18 @@ version = "3.6.0+0"
23142352
# ╠═fc859cea-a41a-4d96-bf86-5a23bca19589
23152353
# ╟─1c76d376-ef91-4410-a981-d8a6dea3033f
23162354
# ╠═f7829706-3981-45b5-bdc3-d8b21155229a
2355+
# ╠═1b9844e2-2e2b-47c2-8a5e-3ff3dbfacc16
23172356
# ╠═0b20861c-995c-4890-81b9-98b8aca5095a
23182357
# ╠═0ad45abb-0d9f-4e8d-b097-b0b42ba024f7
23192358
# ╠═83042a1e-f964-483d-b316-a486cfabd7e0
23202359
# ╟─d2d925a0-8fd5-4345-8ee2-fa4dc4a75407
23212360
# ╠═9eb166fa-360e-4a5d-a2ac-9113c2f264b3
23222361
# ╠═aa2d455e-c9fc-4a7b-b50e-77709481c2a7
23232362
# ╟─f912ee44-15ed-469f-b417-cf7d8d87146e
2363+
# ╠═7147bba2-78ae-49fa-ae35-2c815ee188ae
2364+
# ╠═cdf61aff-ec98-4174-9d48-c287977742cb
2365+
# ╠═8e963ae6-06fd-47fc-a1de-e884468de234
2366+
# ╠═78f7884c-3523-4d96-9dfc-fbe9de36a86b
23242367
# ╠═2a986721-f513-488d-970e-4797f0de135f
23252368
# ╟─00000000-0000-0000-0000-000000000001
23262369
# ╟─00000000-0000-0000-0000-000000000002

lecture_04_interop.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ using CondaPkg; CondaPkg.add("seaborn")
4242
# ╔═╡ 1a7f20f9-a69c-49a2-a012-ed59657cc29f
4343
using PythonCall, RDatasets
4444

45+
# ╔═╡ e036e0b1-60f5-4670-9956-15e74d010ee9
46+
using MPI, Serialization, StaticArrays
47+
4548
# ╔═╡ 0e88ed74-261d-4aad-82dc-ed8076684406
4649
using Measurements
4750

@@ -200,6 +203,22 @@ macro mpi(np, expr)
200203
end
201204
end
202205

206+
# ╔═╡ fa98c58b-e61b-4762-a89f-58cf6b5a50d0
207+
@mpi np let
208+
using StaticArrays
209+
210+
MPI.Init()
211+
comm = MPI.COMM_WORLD
212+
213+
x = ones(SVector{3, Float64})
214+
sum = MPI.Allreduce([x], +, comm)
215+
216+
if MPI.Comm_rank(comm) == 0
217+
@show sum
218+
end
219+
nothing
220+
end
221+
203222
# ╔═╡ c739f61d-7104-4ae4-9934-fc98657fc2fc
204223
md"""
205224
Compute $\int_0^1 \frac{4}{1+x^2} dx = [4 * atan(x)]_0^1$ which evaluates to π
@@ -514,25 +533,6 @@ md"""
514533
- [2025 RSE Course](https://vchuravy.dev/rse-course)
515534
"""
516535
517-
# ╔═╡ fa98c58b-e61b-4762-a89f-58cf6b5a50d0
518-
@mpi np let
519-
using StaticArrays
520-
521-
MPI.Init()
522-
comm = MPI.COMM_WORLD
523-
524-
x = ones(SVector{3, Float64})
525-
sum = MPI.Allreduce([x], +, comm)
526-
527-
if MPI.Comm_rank(comm) == 0
528-
@show sum
529-
end
530-
nothing
531-
end
532-
533-
# ╔═╡ e036e0b1-60f5-4670-9956-15e74d010ee9
534-
using MPI, Serialization, StaticArrays
535-
536536
# ╔═╡ 00000000-0000-0000-0000-000000000001
537537
PLUTO_PROJECT_TOML_CONTENTS = """
538538
[deps]

0 commit comments

Comments
 (0)