|
| 1 | +# GPU Ensemble Simulation with Random Decay Rates |
| 2 | + |
| 3 | +In this tutorial, we demonstrate how to perform GPU-accelerated ensemble simulations using DiffEqGPU.jl. We model an exponential decay ODE: |
| 4 | +\[ u'(t) = -\lambda \, u(t) \] |
| 5 | +with the twist that each trajectory uses a random decay rate \(\lambda\) sampled uniformly from \([0.5, 1.5]\). |
| 6 | + |
| 7 | +## Setup |
| 8 | + |
| 9 | +We first define the ODE and set up an `EnsembleProblem` that randomizes the decay rate for each trajectory. |
| 10 | + |
| 11 | +```julia |
| 12 | +# Make sure you have the necessary packages installed |
| 13 | +# using Pkg |
| 14 | +# Pkg.add(["OrdinaryDiffEq", "DiffEqGPU", "CUDA", "Random", "Statistics", "Plots"]) |
| 15 | +# # Depending on your system, you might need to configure CUDA.jl: |
| 16 | +# # import Pkg; Pkg.build("CUDA") |
| 17 | +using OrdinaryDiffEq, DiffEqGPU, CUDA, Random, Statistics, Plots |
| 18 | + |
| 19 | +# Set a random seed for reproducibility |
| 20 | +Random.seed!(123) |
| 21 | + |
| 22 | +# Define the decay ODE: du/dt = -λ * u, with initial value u(0) = 1. |
| 23 | +decay(u, p, t) = -p * u |
| 24 | + |
| 25 | +# Setup initial condition and time span (using Float32 for GPU efficiency) |
| 26 | +u0 = 1.0f0 |
| 27 | +tspan = (0.0f0, 5.0f0) |
| 28 | +base_param = 1.0f0 |
| 29 | +prob = ODEProblem(decay, u0, tspan, base_param) |
| 30 | + |
| 31 | +# Define a probability function that randomizes λ for each ensemble member. |
| 32 | +# Each trajectory's λ is sampled uniformly from [0.5, 1.5]. |
| 33 | +prob_func = (prob, i, repeat) -> begin |
| 34 | + new_λ = 0.5f0 + 1.0f0 * rand() |
| 35 | + remake(prob, p = new_λ) |
| 36 | +end |
| 37 | + |
| 38 | +ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false) |
| 39 | +``` |
| 40 | + |
| 41 | +# Solving on GPU and CPU |
| 42 | + |
| 43 | +Here we solve the ensemble problem on both GPU and CPU. We use 10,000 trajectories with a fixed time step to facilitate performance comparison. |
| 44 | + |
| 45 | +```julia |
| 46 | +# Number of trajectories |
| 47 | +num_trajectories = 10_000 |
| 48 | + |
| 49 | +# Solve on GPU (check for CUDA availability) |
| 50 | +if CUDA.has_cuda() |
| 51 | + @info "Running GPU simulation..." |
| 52 | + gpu_sol = solve(ensemble_prob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()); |
| 53 | + trajectories = num_trajectories, dt = 0.01f0, adaptive = false) |
| 54 | +else |
| 55 | + @warn "CUDA not available. Skipping GPU simulation." |
| 56 | + gpu_sol = nothing |
| 57 | +end |
| 58 | + |
| 59 | +# Solve on CPU using multi-threading |
| 60 | +@info "Running CPU simulation..." |
| 61 | +cpu_sol = solve(ensemble_prob, Tsit5(), EnsembleThreads(); |
| 62 | + trajectories = num_trajectories, dt = 0.01f0, adaptive = false) |
| 63 | + |
| 64 | + |
| 65 | +``` |
| 66 | + |
| 67 | +# Performance Comparison |
| 68 | + |
| 69 | +We measure the performance of each simulation. (Note: The first run may include compilation time.) |
| 70 | + |
| 71 | +```julia |
| 72 | + |
| 73 | +# Warm-up (first run) for GPU if applicable |
| 74 | +if gpu_sol !== nothing |
| 75 | + @time solve(ensemble_prob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()); |
| 76 | + trajectories = num_trajectories, dt = 0.01f0, adaptive = false) |
| 77 | +end |
| 78 | + |
| 79 | +@time cpu_sol = solve(ensemble_prob, Tsit5(), EnsembleThreads(); |
| 80 | + trajectories = num_trajectories, dt = 0.01f0, adaptive = false) |
| 81 | +``` |
| 82 | + |
| 83 | +# Statistical Analysis and Visualization |
| 84 | +We analyze the ensemble by computing the mean and standard deviation of u(t)u(t) across trajectories, and then visualize the results. |
| 85 | + |
| 86 | +```julia |
| 87 | + |
| 88 | +# Assuming all solutions have the same time points (fixed dt & saveat) |
| 89 | +t_vals = cpu_sol[1].t |
| 90 | +num_times = length(t_vals) |
| 91 | +ensemble_vals = reduce(hcat, [sol.u for sol in cpu_sol]) # each column corresponds to one trajectory |
| 92 | + |
| 93 | +# Compute ensemble statistics |
| 94 | +mean_u = [mean(ensemble_vals[i, :]) for i in 1:num_times] |
| 95 | +std_u = [std(ensemble_vals[i, :]) for i in 1:num_times] |
| 96 | + |
| 97 | +# Plot the mean trajectory with ±1 standard deviation |
| 98 | +p1 = plot(t_vals, mean_u, ribbon = std_u, xlabel = "Time", ylabel = "u(t)", |
| 99 | + title = "Ensemble Mean and ±1σ", label = "Mean ± σ", legend = :topright) |
| 100 | + |
| 101 | +# Histogram of final values (u at t=5) |
| 102 | +final_vals = ensemble_vals[end, :] |
| 103 | +p2 = histogram(final_vals, bins = 30, xlabel = "Final u", ylabel = "Frequency", |
| 104 | + title = "Distribution of Final Values", label = "") |
| 105 | +plot(p1, p2, layout = (1,2), size = (900,400)) |
| 106 | + |
| 107 | + |
0 commit comments