Skip to content

Commit 441d905

Browse files
committed
cleanup animations
1 parent 599d771 commit 441d905

File tree

1 file changed

+59
-112
lines changed
  • src/DynamicVehicleScheduling

1 file changed

+59
-112
lines changed

src/DynamicVehicleScheduling/plot.jl

Lines changed: 59 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,13 @@ The returned dictionary contains:
231231
This lets plotting code build figures without depending on plotting internals.
232232
"""
233233
function build_plot_data(data_samples::Vector{<:DataSample})
234-
return [build_state_data(sample.instance.state) for sample in data_samples]
234+
state_data = [build_state_data(sample.instance.state) for sample in data_samples]
235+
rewards = [sample.instance.reward for sample in data_samples]
236+
routess = [sample.y_true for sample in data_samples]
237+
return [
238+
(; state..., reward, routes) for
239+
(state, reward, routes) in zip(state_data, rewards, routess)
240+
]
235241
end
236242

237243
"""
@@ -385,34 +391,24 @@ function animate_epochs(
385391
end
386392

387393
pd = build_plot_data(data_samples)
388-
n_epochs = pd.n_epochs
389-
# Build all_coordinates from pd.coordinates
390-
all_coordinates = []
391-
for coords in pd.coordinates
392-
for c in coords
393-
push!(all_coordinates, (; x=c[1], y=c[2]))
394-
end
395-
end
396-
epoch_costs = copy(pd.epoch_costs)
397-
398-
if isempty(all_coordinates)
399-
error("No valid coordinates found in data samples")
400-
end
394+
epoch_costs = [-sample.instance.reward for sample in data_samples]
401395

402-
# Calculate cumulative costs for the cost bar
403-
cumulative_costs = cumsum(epoch_costs)
404-
max_cost = isempty(cumulative_costs) ? 1.0 : maximum(cumulative_costs)
405-
406-
xlims = (
407-
minimum(p.x for p in all_coordinates) - margin,
408-
maximum(p.x for p in all_coordinates) + margin,
409-
)
396+
# Calculate global xlims and ylims from all states
397+
x_min = minimum(min(data.x_depot, minimum(data.x_customers)) for data in pd)
398+
x_max = maximum(max(data.x_depot, maximum(data.x_customers)) for data in pd)
399+
y_min = minimum(min(data.y_depot, minimum(data.y_customers)) for data in pd)
400+
y_max = maximum(max(data.y_depot, maximum(data.y_customers)) for data in pd)
410401

411-
# Add extra margin at the top for legend space and cost bar
412-
y_min = minimum(p.y for p in all_coordinates) - margin
413-
y_max = maximum(p.y for p in all_coordinates) + margin
414-
y_range = y_max - y_min
402+
xlims = (x_min - margin, x_max + margin)
403+
# Add extra margin at the top for legend space
404+
y_range = y_max - y_min + 2 * margin
415405
legend_margin = y_range * legend_margin_factor
406+
ylims = (y_min - margin, y_max + margin + legend_margin)
407+
408+
# Calculate global color limits for consistent scaling across subplots
409+
min_start_time = minimum(minimum(data.start_times) for data in pd)
410+
max_start_time = maximum(maximum(data.start_times) for data in pd)
411+
clims = (min_start_time, max_start_time)
416412

417413
# Adjust x-axis if showing cost bar
418414
if show_cost_bar
@@ -422,48 +418,13 @@ function animate_epochs(
422418
xlims = (x_min, x_max + cost_bar_space)
423419
end
424420

425-
ylims = (y_min, y_max + legend_margin)
426-
427-
# Calculate global color limits
428-
all_start_times = []
429-
for sample in data_samples
430-
if !isnothing(sample.instance.state)
431-
times = start_time(sample.instance.state)
432-
append!(all_start_times, times)
433-
end
434-
end
435-
436-
clims = if !isempty(all_start_times)
437-
(minimum(all_start_times), maximum(all_start_times))
438-
else
439-
(0.0, 1.0)
440-
end
441-
442-
# Helper function to check if routes exist and are non-empty
443-
function has_routes(routes)
444-
if isnothing(routes)
445-
return false
446-
elseif routes isa Vector{Vector{Int}}
447-
return any(!isempty(route) for route in routes)
448-
elseif routes isa Vector{Int}
449-
return !isempty(routes)
450-
elseif routes isa BitMatrix
451-
return any(routes)
452-
else
453-
return false
454-
end
455-
end
456-
457-
# Create frame plan: determine which epochs have routes
421+
# Create interleaved frame plan: always include a state frame and a routes frame
422+
# for every epoch. The routes-frame will render a 'no routes' message when
423+
# no routes are present, which keeps timing consistent and the code simpler.
458424
frame_plan = []
459-
for (epoch_idx, sample) in enumerate(data_samples)
460-
# Always add state frame
425+
for (epoch_idx, _) in enumerate(data_samples)
461426
push!(frame_plan, (epoch_idx, :state))
462-
463-
# Add routes frame only if routes exist
464-
if has_routes(sample.y_true)
465-
push!(frame_plan, (epoch_idx, :routes))
466-
end
427+
push!(frame_plan, (epoch_idx, :routes))
467428
end
468429

469430
total_frames = length(frame_plan)
@@ -474,63 +435,47 @@ function animate_epochs(
474435
sample = data_samples[epoch_idx]
475436
state = sample.instance.state
476437

477-
if isnothing(state)
478-
# Empty frame for missing data
479-
fig = plot(;
438+
if frame_type == :routes
439+
fig = plot_routes(
440+
state,
441+
sample.y_true;
480442
xlims=xlims,
481443
ylims=ylims,
482-
title="Epoch $epoch_idx (No Data)",
444+
clims=clims,
445+
title="Epoch $(state.current_epoch) - Routes Dispatched",
483446
titlefontsize=titlefontsize,
484447
guidefontsize=guidefontsize,
448+
legendfontsize=legendfontsize,
485449
tickfontsize=tickfontsize,
486-
legend=false,
450+
show_axis_labels=show_axis_labels,
451+
markerstrokewidth=0.5,
452+
show_route_labels=false,
453+
show_colorbar=show_colorbar,
454+
size=figsize,
455+
kwargs...,
456+
)
457+
else # frame_type == :state
458+
# Show state only
459+
fig = plot_state(
460+
state;
461+
xlims=xlims,
462+
ylims=ylims,
463+
clims=clims,
464+
title="Epoch $(state.current_epoch) - Available Requests",
465+
titlefontsize=titlefontsize,
466+
guidefontsize=guidefontsize,
467+
legendfontsize=legendfontsize,
468+
tickfontsize=tickfontsize,
469+
show_axis_labels=show_axis_labels,
470+
markerstrokewidth=0.5,
471+
show_colorbar=show_colorbar,
487472
size=figsize,
488473
kwargs...,
489474
)
490-
else
491-
if frame_type == :routes
492-
# Show state with routes
493-
fig = plot_routes(
494-
state,
495-
sample.y_true;
496-
xlims=xlims,
497-
ylims=ylims,
498-
clims=clims,
499-
title="Epoch $(state.current_epoch) - Routes Dispatched",
500-
titlefontsize=titlefontsize,
501-
guidefontsize=guidefontsize,
502-
legendfontsize=legendfontsize,
503-
tickfontsize=tickfontsize,
504-
show_axis_labels=show_axis_labels,
505-
markerstrokewidth=0.5,
506-
show_route_labels=false,
507-
show_colorbar=show_colorbar,
508-
size=figsize,
509-
kwargs...,
510-
)
511-
else # frame_type == :state
512-
# Show state only
513-
fig = plot_state(
514-
state;
515-
xlims=xlims,
516-
ylims=ylims,
517-
clims=clims,
518-
title="Epoch $(state.current_epoch) - Available Requests",
519-
titlefontsize=titlefontsize,
520-
guidefontsize=guidefontsize,
521-
legendfontsize=legendfontsize,
522-
tickfontsize=tickfontsize,
523-
show_axis_labels=show_axis_labels,
524-
markerstrokewidth=0.5,
525-
show_colorbar=show_colorbar,
526-
size=figsize,
527-
kwargs...,
528-
)
529-
end
530475
end
531476

532477
# Add cost bar if requested
533-
if show_cost_bar && !isempty(cumulative_costs)
478+
if show_cost_bar
534479
# Calculate cost bar position on the right side of the plot
535480
x_min, x_max = xlims
536481
x_range = x_max - x_min
@@ -558,6 +503,7 @@ function animate_epochs(
558503
end
559504

560505
# Calculate filled height
506+
max_cost = sum(epoch_costs)
561507
if max_cost > 0
562508
filled_height = (current_cost / max_cost) * bar_height
563509
else
@@ -815,6 +761,7 @@ function animate_solutions_side_by_side(
815761
legendfontsize=legendfontsize,
816762
tickfontsize=tickfontsize,
817763
show_axis_labels=show_axis_labels,
764+
show_colorbar=false,
818765
markerstrokewidth=0.5,
819766
size=(floor(Int, figsize[1] / n_solutions), figsize[2]),
820767
kwargs...,

0 commit comments

Comments
 (0)