
Commit e53706e

Merge pull request #346 from ParyaRoustaee/add-random-decay-tutorial

Updated the random_decay tutorial with StaticArrays and improved the benchmark and plots.

2 parents: 8ea3c51 + b2f7018, commit e53706e

2 files changed: 248 additions, 50 deletions

docs/src/examples/gpu_ensemble_random_decay.md

Lines changed: 246 additions & 49 deletions

In this tutorial, we demonstrate how to perform GPU-accelerated ensemble simulations using DiffEqGPU.jl. We model an exponential decay ODE:

\[ u'(t) = -\lambda \, u(t) \]

with the twist that each trajectory uses a random decay rate \(\lambda\) sampled uniformly from \([0.5, 1.5]\). This version uses `StaticArrays` for the state vector, which is often more robust and performant for small ODEs on the GPU.
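
For reference, each trajectory has the closed-form solution \(u(t) = u_0 e^{-\lambda t}\). With \(u_0 = 1\) and \(\lambda\) uniform on \([0.5, 1.5]\), the exact ensemble mean is

\[ \mathbb{E}[u(t)] = \int_{0.5}^{1.5} e^{-\lambda t} \, d\lambda = \frac{e^{-0.5 t} - e^{-1.5 t}}{t} \quad (t > 0), \]

which the simulated ensemble means computed later in the tutorial should closely approximate.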

## Setup

We first load the necessary packages, define the ODE using `StaticArrays`, and set up an `EnsembleProblem` that randomizes the decay rate for each trajectory.

```@example decay
# Make sure you have the necessary packages installed
# using Pkg
# Pkg.add(["OrdinaryDiffEq", "DiffEqGPU", "CUDA", "StaticArrays", "Random", "Statistics", "Plots"])
# # Depending on your system, you might need to configure CUDA.jl:
# # import Pkg; Pkg.build("CUDA")
using OrdinaryDiffEq, DiffEqGPU, CUDA, StaticArrays, Random, Statistics, Plots

# Set a random seed for reproducibility
Random.seed!(123)

# Define the decay ODE using the OUT-OF-PLACE form for StaticArrays:
# f(u, p, t) should return a new SVector representing the derivative du/dt.
# This form is generally preferred for StaticArrays on the GPU.
function decay_static(u::SVector, p, t)
    λ = p[1] # Parameter is expected as a scalar or single-element container
    return @SVector [-λ * u[1]]
end

# Setup initial condition as a 1-element SVector (Static Array).
# Using StaticArrays explicitly helps the GPU compiler generate efficient, static code.
u0 = @SVector [1.0f0]

# Define time span (using Float32 for GPU efficiency)
tspan = (0.0f0, 5.0f0)

# Define the base parameter (will be overridden by prob_func).
# We wrap it in an SVector to match how parameters might be handled internally,
# though a scalar Float32 often works too. Using SVector can sometimes avoid type issues.
base_param = @SVector [1.0f0]

# Create the ODEProblem using the static function and SVector initial condition
prob = ODEProblem{false}(decay_static, u0, tspan, base_param) # Use {false} for out-of-place

# Define a probability function that randomizes λ for each ensemble member.
# Each trajectory's λ is sampled uniformly from [0.5, 1.5].
prob_func = (prob, i, repeat) -> begin
    new_λ = 0.5f0 + 1.0f0 * rand(Float32)
    remake(prob, p = @SVector [new_λ])
end

ensemble_prob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
```
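
As a quick optional check (not part of the original tutorial code), you can call `prob_func` by hand and confirm that each generated problem carries its own randomized decay rate; everything used below is defined in the setup block above.

```julia
# Each call to prob_func should return a problem with a fresh λ drawn from [0.5, 1.5).
sample_probs = [prob_func(prob, i, false) for i in 1:5]
println([sp.p[1] for sp in sample_probs])  # five different decay rates
```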

## Solving on GPU and CPU

Here we solve the ensemble problem on both GPU and CPU. We use 10,000 trajectories with a fixed time step to facilitate the performance comparison. For benchmarking, we initially solve without saving every step (`save_everystep = false`).

```@example decay
# Number of trajectories
num_trajectories = 10_000

# --- GPU Simulation ---
gpu_sol = nothing # Initialize the variable for the performance run
if CUDA.has_cuda() && CUDA.functional()
    @info "Running GPU simulation (initial run for performance, includes compilation)..."
    # Use EnsembleGPUKernel with the CUDABackend.
    # GPUTsit5 is suitable for non-stiff ODEs on the GPU.
    # save_everystep = false reduces memory overhead and transfer time if only final states are needed.
    gpu_sol = solve(ensemble_prob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend());
        trajectories = num_trajectories,
        save_everystep = false, # Crucial for performance measurement
        dt = 0.01f0, adaptive = false)
else
    @warn "CUDA not available. Skipping GPU simulation."
    gpu_sol = nothing
end

# --- CPU Simulation ---
@info "Running CPU simulation (initial run for performance, includes compilation)..."
# Use EnsembleThreads for multi-threaded CPU execution.
# Tsit5 is the CPU counterpart to GPUTsit5.
# Match GPU saving options for a fair comparison.
cpu_sol = solve(ensemble_prob, Tsit5(), EnsembleThreads();
    trajectories = num_trajectories,
    save_everystep = false, # Match GPU setting
    dt = 0.01f0, adaptive = false)
```
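
One practical note on the CPU baseline: `EnsembleThreads` can only spread trajectories across the threads the Julia session was started with, so the comparison is only meaningful if multithreading is actually enabled. A quick check (not part of the original tutorial):

```julia
# Start Julia with `julia --threads=auto` (or set JULIA_NUM_THREADS) to enable multithreading.
@info "Julia is running with $(Threads.nthreads()) thread(s) for EnsembleThreads."
```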

## Performance Comparison

We re-run the simulations using `@time` to get a cleaner measurement of the execution time, excluding the initial compilation overhead.

```@example decay
# --- GPU Timing (Second Run) ---
if gpu_sol !== nothing
    @info "Timing GPU simulation (second run, no data saving)..."
    @time solve(ensemble_prob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend());
        trajectories = num_trajectories, save_everystep = false,
        dt = 0.01f0, adaptive = false)
else
    @info "Skipping GPU timing (CUDA not available)."
end

# --- CPU Timing (Second Run) ---
@info "Timing CPU simulation (second run, no data saving)..."
@time cpu_sol = solve(ensemble_prob, Tsit5(), EnsembleThreads();
    trajectories = num_trajectories, save_everystep = false,
    dt = 0.01f0, adaptive = false)

# Note: The first @time includes compilation and setup; the second is more representative
# of the pure computation time for subsequent runs. Expect the GPU to be significantly
# faster for a large number of trajectories like 10,000.
```
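
For more stable timings than a single `@time` call, BenchmarkTools.jl runs the target expression many times and reports the minimum time, which is usually a more reliable estimate. This is an optional sketch, not part of the tutorial; it assumes BenchmarkTools is installed (`Pkg.add("BenchmarkTools")`):

```julia
using BenchmarkTools

# Interpolate globals with `$` so the measurement is not dominated by global-variable lookups.
@btime solve($ensemble_prob, Tsit5(), EnsembleThreads();
    trajectories = $num_trajectories, save_everystep = false,
    dt = 0.01f0, adaptive = false)
```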

## CPU Statistical Analysis and Visualization

To visualize the evolution of the ensemble statistics (mean and standard deviation) over time using the *CPU results*, we need the solutions at multiple time points. We re-solve the problem on the CPU, this time saving the results at each step (`save_everystep = true`). We then process the results and plot them.

```@example decay
# Re-solve on CPU, saving all steps for plotting
@info "Re-solving CPU simulation to collect data for plotting..."
cpu_sol_plot = solve(ensemble_prob, Tsit5(), EnsembleThreads();
    trajectories = num_trajectories,
    save_everystep = true, # Save data at each dt step
    dt = 0.01f0,
    adaptive = false)

# Extract time points from the first trajectory's solution (assuming all are the same)
t_vals_cpu = cpu_sol_plot[1].t
num_times_cpu = length(t_vals_cpu)

# Create a matrix to hold the results: rows = time, columns = trajectories.
# Initialize with NaN in case some trajectories fail.
ensemble_vals_cpu = fill(NaN32, num_times_cpu, num_trajectories) # Use Float32

# Extract the state value (u[1]) from each trajectory at each time point
for i in 1:num_trajectories
    # Check if the trajectory simulation was successful and the data looks valid
    if cpu_sol_plot[i].retcode == ReturnCode.Success &&
       length(cpu_sol_plot[i].u) == num_times_cpu
        # sol.u is a Vector{SVector{1, Float32}}; we need the element from each SVector.
        ensemble_vals_cpu[:, i] .= getindex.(cpu_sol_plot[i].u, 1)
    else
        @warn "CPU Trajectory $i failed or had unexpected length. Retcode: $(cpu_sol_plot[i].retcode). Length: $(length(cpu_sol_plot[i].u)). Skipping."
        # Column remains NaN
    end
end

# Filter out failed trajectories (columns with NaN)
successful_traj_indices_cpu = findall(
    j -> !all(isnan, view(ensemble_vals_cpu, :, j)), 1:num_trajectories)
num_successful_cpu = length(successful_traj_indices_cpu)

if num_successful_cpu == 0
    @error "No successful CPU trajectories to analyze!"
else
    if num_successful_cpu < num_trajectories
        @warn "$(num_trajectories - num_successful_cpu) CPU trajectories failed. Analysis based on $num_successful_cpu trajectories."
        ensemble_vals_cpu = ensemble_vals_cpu[:, successful_traj_indices_cpu] # Keep only successful ones
    end

    # Compute ensemble statistics over successful CPU trajectories
    mean_u_cpu = mapslices(mean, ensemble_vals_cpu, dims = 2)[:]
    std_u_cpu = mapslices(std, ensemble_vals_cpu, dims = 2)[:]

    # --- Plotting CPU Results ---
    p1_cpu = plot(
        t_vals_cpu, mean_u_cpu, ribbon = std_u_cpu, xlabel = "Time (t)", ylabel = "u(t)",
        title = "CPU Ensemble Mean ±1σ ($num_successful_cpu Trajectories)",
        label = "Mean u(t)", fillalpha = 0.3, lw = 2)

    final_vals_cpu = ensemble_vals_cpu[end, :]
    p2_cpu = histogram(final_vals_cpu, bins = 30, normalize = :probability,
        xlabel = "Final u(T)", ylabel = "Probability",
        title = "CPU Distribution of Final Values (t=$(tspan[2]))",
        label = "", legend = false)

    plot_cpu = plot(
        p1_cpu, p2_cpu, layout = (1, 2), size = (1000, 450), legend = :outertopright)
    @info "Displaying CPU analysis plot..."
    display(plot_cpu)
end
```
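
If the CPU run above succeeded, the simulated mean can be checked against the analytical ensemble mean \(\mathbb{E}[u(t)] = (e^{-0.5t} - e^{-1.5t})/t\) derived in the introduction. This overlay is an optional addition (not in the original tutorial) and reuses `t_vals_cpu` and `p1_cpu` from the block above:

```julia
# Overlay the analytical ensemble mean on the simulated CPU mean as a visual consistency check.
analytic_mean(t) = t == 0 ? one(t) : (exp(-0.5f0 * t) - exp(-1.5f0 * t)) / t
plot!(p1_cpu, t_vals_cpu, analytic_mean.(t_vals_cpu),
    label = "Analytical mean", linestyle = :dash, lw = 2)
```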

## GPU Statistical Analysis and Visualization

Similarly, we can analyze the results from the *GPU simulation*. This requires re-running the simulation to save the time series data and then transferring the data from GPU memory to CPU RAM for analysis and plotting using standard tools. Note that this data transfer can be a significant bottleneck for large numbers of trajectories or time steps.

```@example decay
# Check if the GPU simulation was successful initially before proceeding
gpu_analysis_plot = nothing # Initialize plot variable
if gpu_sol !== nothing && CUDA.has_cuda() && CUDA.functional()
    @info "Re-solving GPU simulation to collect data for plotting..."
    # Add @time to see the impact of saving data
    @time gpu_sol_plot = solve(
        ensemble_prob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend());
        trajectories = num_trajectories,
        save_everystep = true, # Save data at each dt step on the GPU
        dt = 0.01f0,
        adaptive = false)

    # --- Data Transfer and Analysis ---
    # The result gpu_sol_plot should be an EnsembleSolution containing a Vector{ODESolution}.
    # Accessing it might implicitly transfer, or we can use Array().
    @info "Transferring GPU solution objects (if needed) and processing..."
    # Try accessing .u directly first, assuming it holds the Vector{ODESolution}.
    # If this fails, we try Array(gpu_sol_plot) -> Vector{ODESolution}.
    solutions_vector = gpu_sol_plot.u

    # Check whether the transfer actually happened / we have the right type
    if !(solutions_vector isa AbstractVector{<:ODESolution})
        @warn "gpu_sol_plot.u is not a Vector{ODESolution}. Trying Array(gpu_sol_plot)..."
        # This may explicitly trigger the transfer and construction of ODESolution objects on the CPU.
        # Note: this can be slow/memory intensive!
        @time solutions_vector = Array(gpu_sol_plot)
        if !(solutions_vector isa AbstractVector{<:ODESolution})
            @error "Could not obtain Vector{ODESolution} from GPU result. Type is $(typeof(solutions_vector)). Aborting GPU analysis."
            solutions_vector = nothing # Mark as failed
        end
    end

    if solutions_vector !== nothing
        # Extract time points from the first successful trajectory's solution
        first_successful_gpu_idx = findfirst(
            sol -> sol.retcode == ReturnCode.Success, solutions_vector)

        if first_successful_gpu_idx === nothing
            @error "No successful GPU trajectories found in the returned solutions vector!"
        else
            t_vals_gpu = solutions_vector[first_successful_gpu_idx].t
            num_times_gpu = length(t_vals_gpu)

            # Create a matrix to hold the results from the GPU (now on the CPU)
            ensemble_vals_gpu = fill(NaN32, num_times_gpu, num_trajectories) # Use Float32

            # Extract the state value (u[1]) from each trajectory
            num_processed = 0
            for i in 1:num_trajectories
                sol = solutions_vector[i] # Access the i-th ODESolution
                if sol.retcode == ReturnCode.Success
                    # Check consistency of time points (optional but good practice)
                    if length(sol.t) == num_times_gpu # && sol.t == t_vals_gpu (can be a slow check)
                        # sol.u is likely Vector{SVector{1, Float32}} after transfer
                        ensemble_vals_gpu[:, i] .= getindex.(sol.u, 1)
                        num_processed += 1
                    else
                        @warn "GPU Trajectory $i succeeded but time points mismatch (Expected $(num_times_gpu), Got $(length(sol.t))). Skipping."
                        # Column remains NaN
                    end
                else
                    # @warn "GPU Trajectory $i failed with retcode: $(sol.retcode). Skipping." # Potentially verbose
                    # Column remains NaN
                end
            end
            @info "Processed $num_processed successful GPU trajectories."

            # Filter out failed trajectories (columns with NaN)
            successful_traj_indices_gpu = findall(
                j -> !all(isnan, view(ensemble_vals_gpu, :, j)), 1:num_trajectories)
            num_successful_gpu = length(successful_traj_indices_gpu)

            if num_successful_gpu == 0
                @error "No successful GPU trajectories suitable for analysis after processing!"
            else
                if num_successful_gpu < num_trajectories
                    # This count includes those skipped due to time mismatch or failure
                    @warn "Analysis based on $num_successful_gpu trajectories (out of $num_trajectories initial)."
                    # Keep only successful, valid ones
                    ensemble_vals_gpu = ensemble_vals_gpu[:, successful_traj_indices_gpu]
                end

                # Compute ensemble statistics over successful GPU trajectories
                mean_u_gpu = mapslices(mean, ensemble_vals_gpu, dims = 2)[:]
                std_u_gpu = mapslices(std, ensemble_vals_gpu, dims = 2)[:]

                # --- Plotting GPU Results ---
                p1_gpu = plot(t_vals_gpu, mean_u_gpu, ribbon = std_u_gpu,
                    xlabel = "Time (t)", ylabel = "u(t)",
                    title = "GPU Ensemble Mean ±1σ ($num_successful_gpu Trajectories)",
                    label = "Mean u(t)", fillalpha = 0.3, lw = 2)

                final_vals_gpu = ensemble_vals_gpu[end, :]
                p2_gpu = histogram(final_vals_gpu, bins = 30, normalize = :probability,
                    xlabel = "Final u(T)", ylabel = "Probability",
                    title = "GPU Distribution of Final Values (t=$(tspan[2]))",
                    label = "", legend = false)

                gpu_analysis_plot = plot(p1_gpu, p2_gpu, layout = (1, 2),
                    size = (1000, 450), legend = :outertopright)
                @info "Displaying GPU analysis plot..."
                display(gpu_analysis_plot)

                # Clean up large structures if memory is a concern
                ensemble_vals_gpu = nothing
                solutions_vector = nothing
                # GC.gc()
            end
        end
    end # End if solutions_vector !== nothing
else
    @warn "Skipping GPU analysis section because the initial GPU performance run failed or CUDA is unavailable."
end
```
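
As a final optional sanity check (not part of the original tutorial), the CPU and GPU ensemble means should agree to within Monte Carlo error, since both draw \(\lambda\) from the same distribution:

```julia
# Compare the CPU and GPU ensemble means if both analysis blocks above produced them.
if @isdefined(mean_u_cpu) && @isdefined(mean_u_gpu) && length(mean_u_cpu) == length(mean_u_gpu)
    @info "Max |CPU mean - GPU mean| = $(maximum(abs.(mean_u_cpu .- mean_u_gpu)))"
end
```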

src/algorithms.jl

Lines changed: 2 additions & 1 deletion
Only the `@time solve(...)` line in the docstring example changes; it is now wrapped across two lines:

```julia
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

@time sol = solve(
    monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()), trajectories = 10_000,
    adaptive = false, dt = 0.1f0)
```
