Skip to content

Commit 420eb7b

Browse files
committed
fix tests and cleanup
1 parent 441d905 commit 420eb7b

File tree

5 files changed

+16
-348
lines changed

5 files changed

+16
-348
lines changed

src/DynamicVehicleScheduling/anticipative_solver.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,14 @@ function anticipative_solver(
9292
job_indices = 2:nb_nodes
9393
epoch_indices = T
9494

95-
@variable(model, y[i=1:nb_nodes, j=1:nb_nodes, t=epoch_indices]; binary=true)
95+
@variable(model, y[i = 1:nb_nodes, j = 1:nb_nodes, t = epoch_indices]; binary=true)
9696

9797
@objective(
9898
model,
9999
Max,
100100
sum(
101-
-duration[i, j] * y[i, j, t] for i in 1:nb_nodes, j in 1:nb_nodes,
102-
t in epoch_indices
101+
-duration[i, j] * y[i, j, t] for
102+
i in 1:nb_nodes, j in 1:nb_nodes, t in epoch_indices
103103
)
104104
)
105105

@@ -171,14 +171,12 @@ function anticipative_solver(
171171
routes = epoch_routes[i]
172172
epoch_customers = epoch_indices[i]
173173

174-
y_true =
175-
VSPSolution(
176-
Vector{Int}[
177-
map(idx -> findfirst(==(idx), epoch_customers), route) for
178-
route in routes
179-
];
180-
max_index=length(epoch_customers),
181-
).edge_matrix
174+
y_true = VSPSolution(
175+
Vector{Int}[
176+
map(idx -> findfirst(==(idx), epoch_customers), route) for route in routes
177+
];
178+
max_index=length(epoch_customers),
179+
).edge_matrix
182180

183181
location_indices = customer_index[epoch_customers]
184182
new_coordinates = env.instance.static_instance.coordinate[location_indices]
@@ -202,7 +200,8 @@ function anticipative_solver(
202200
is_must_dispatch[2:end] .= true
203201
else
204202
is_must_dispatch[2:end] .=
205-
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .> new_start_time[2:end]
203+
planning_start_time .+ epoch_duration .+ @view(new_duration[1, 2:end]) .>
204+
new_start_time[2:end]
206205
end
207206
is_postponable[2:end] .= .!is_must_dispatch[2:end]
208207
# TODO: avoid code duplication with add_new_customers!

src/DynamicVehicleScheduling/plot.jl

Lines changed: 1 addition & 289 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,3 @@
1-
"""
2-
Typed container for plot-ready data.
3-
"""
4-
struct PlotData
5-
n_epochs::Int
6-
coordinates::Vector{Vector{Tuple{Float64,Float64}}}
7-
start_times::Vector{Vector{Float64}}
8-
node_types::Vector{Vector{Symbol}}
9-
routes::Vector{Vector{Vector{Int}}}
10-
epoch_costs::Vector{Float64}
11-
end
12-
13-
function PlotData(d::Dict)
14-
return PlotData(
15-
d[:n_epochs],
16-
d[:coordinates],
17-
d[:start_times],
18-
d[:node_types],
19-
d[:routes],
20-
d[:epoch_costs],
21-
)
22-
end
23-
241
function plot_instance(env::DVSPEnv; kwargs...)
252
return plot_instance(env.instance.static_instance; kwargs...)
263
end
@@ -461,7 +438,7 @@ function animate_epochs(
461438
xlims=xlims,
462439
ylims=ylims,
463440
clims=clims,
464-
title="Epoch $(state.current_epoch) - Available Requests",
441+
title="Epoch $(state.current_epoch) - Available Customers",
465442
titlefontsize=titlefontsize,
466443
guidefontsize=guidefontsize,
467444
legendfontsize=legendfontsize,
@@ -592,268 +569,3 @@ function animate_epochs(
592569

593570
return anim
594571
end
595-
596-
"""
597-
Animate multiple solutions where each solution is provided as its own vector of
598-
`DataSample` objects (one per epoch). This treats each solution's `DataSample`
599-
as the canonical source for that column in the side-by-side animation.
600-
"""
601-
function animate_solutions_side_by_side(
602-
solutions_data_samples::AbstractVector;
603-
solution_names=nothing,
604-
filename="dvsp_solutions_side_by_side.gif",
605-
fps=1,
606-
figsize=(1200, 600),
607-
margin=0.1,
608-
legend_margin_factor=0.15,
609-
titlefontsize=14,
610-
guidefontsize=12,
611-
legendfontsize=11,
612-
tickfontsize=10,
613-
show_axis_labels=true,
614-
show_cost_bar=true,
615-
cost_bar_width=0.05,
616-
cost_bar_margin=0.02,
617-
cost_bar_color_palette=:turbo,
618-
kwargs...,
619-
)
620-
n_solutions = length(solutions_data_samples)
621-
if n_solutions == 0
622-
error("No solutions provided")
623-
end
624-
625-
# Ensure all solution sequences have the same number of epochs
626-
n_epochs = length(solutions_data_samples[1])
627-
for (i, s) in enumerate(solutions_data_samples)
628-
if length(s) != n_epochs
629-
error(
630-
"All solution DataSample vectors must have the same length. Solution $i has length $(length(s)) but expected $n_epochs",
631-
)
632-
end
633-
end
634-
635-
if isnothing(solution_names)
636-
solution_names = ["Solution $(i)" for i in 1:n_solutions]
637-
end
638-
639-
# Collect global coordinates and start times across all solutions/epochs
640-
all_coordinates = []
641-
all_start_times = []
642-
epoch_costs_per_solution = [Float64[] for _ in 1:n_solutions]
643-
644-
for j in 1:n_solutions
645-
samples = solutions_data_samples[j]
646-
for (t, sample) in enumerate(samples)
647-
if !isnothing(sample.instance.state)
648-
append!(all_coordinates, coordinate(sample.instance.state))
649-
append!(all_start_times, start_time(sample.instance.state))
650-
651-
if sample.y_true isa BitMatrix
652-
routes = decode_bitmatrix_to_routes(sample.y_true)
653-
else
654-
routes = sample.y_true isa Vector{Int} ? [sample.y_true] : sample.y_true
655-
end
656-
c = isnothing(routes) ? 0.0 : cost(sample.instance.state, routes)
657-
push!(epoch_costs_per_solution[j], c)
658-
else
659-
push!(epoch_costs_per_solution[j], NaN)
660-
end
661-
end
662-
end
663-
664-
if isempty(all_coordinates)
665-
error("No valid coordinates found in solution data samples")
666-
end
667-
668-
# Global limits
669-
xlims = (
670-
minimum(p.x for p in all_coordinates) - margin,
671-
maximum(p.x for p in all_coordinates) + margin,
672-
)
673-
674-
y_min = minimum(p.y for p in all_coordinates) - margin
675-
y_max = maximum(p.y for p in all_coordinates) + margin
676-
y_range = y_max - y_min
677-
legend_margin = y_range * legend_margin_factor
678-
ylims = (y_min, y_max + legend_margin)
679-
680-
clims = if !isempty(all_start_times)
681-
(minimum(all_start_times), maximum(all_start_times))
682-
else
683-
(0.0, 1.0)
684-
end
685-
686-
# Robust cumulative costs per solution
687-
robust_cumulative = Vector{Vector{Float64}}(undef, n_solutions)
688-
for j in 1:n_solutions
689-
robust = Float64[]
690-
s = 0.0
691-
for c in epoch_costs_per_solution[j]
692-
if !isnan(c)
693-
s += c
694-
end
695-
push!(robust, s)
696-
end
697-
robust_cumulative[j] = robust
698-
end
699-
700-
function has_routes_local(routes)
701-
if isnothing(routes)
702-
return false
703-
elseif routes isa Vector{Vector{Int}}
704-
return any(!isempty(route) for route in routes)
705-
elseif routes isa Vector{Int}
706-
return !isempty(routes)
707-
elseif routes isa BitMatrix
708-
return any(routes)
709-
else
710-
return false
711-
end
712-
end
713-
714-
anim = @animate for t in 1:n_epochs
715-
col_plots = []
716-
for j in 1:n_solutions
717-
sample = solutions_data_samples[j][t]
718-
state = sample.instance.state
719-
routes = sample.y_true
720-
721-
if isnothing(state)
722-
fig = plot(;
723-
xlims=xlims,
724-
ylims=ylims,
725-
title="$(solution_names[j]) - Epoch $t (No Data)",
726-
titlefontsize=titlefontsize,
727-
guidefontsize=guidefontsize,
728-
tickfontsize=tickfontsize,
729-
legend=false,
730-
kwargs...,
731-
)
732-
else
733-
if has_routes_local(routes)
734-
fig = plot_routes(
735-
state,
736-
routes;
737-
xlims=xlims,
738-
ylims=ylims,
739-
clims=clims,
740-
title="$(solution_names[j]) - Epoch $(state.current_epoch)",
741-
titlefontsize=titlefontsize,
742-
guidefontsize=guidefontsize,
743-
legendfontsize=legendfontsize,
744-
tickfontsize=tickfontsize,
745-
show_axis_labels=show_axis_labels,
746-
markerstrokewidth=0.5,
747-
show_route_labels=false,
748-
show_colorbar=false,
749-
size=(floor(Int, figsize[1] / n_solutions), figsize[2]),
750-
kwargs...,
751-
)
752-
else
753-
fig = plot_state(
754-
state;
755-
xlims=xlims,
756-
ylims=ylims,
757-
clims=clims,
758-
title="$(solution_names[j]) - Epoch $(state.current_epoch)",
759-
titlefontsize=titlefontsize,
760-
guidefontsize=guidefontsize,
761-
legendfontsize=legendfontsize,
762-
tickfontsize=tickfontsize,
763-
show_axis_labels=show_axis_labels,
764-
show_colorbar=false,
765-
markerstrokewidth=0.5,
766-
size=(floor(Int, figsize[1] / n_solutions), figsize[2]),
767-
kwargs...,
768-
)
769-
end
770-
end
771-
772-
# cost bar
773-
if show_cost_bar
774-
current_cost = robust_cumulative[j][t]
775-
max_cost = maximum([robust_cumulative[k][end] for k in 1:n_solutions])
776-
777-
x_min, x_max = xlims
778-
x_range = x_max - x_min
779-
bar_x_start = x_max - cost_bar_width * x_range
780-
bar_x_end = x_max - cost_bar_margin * x_range
781-
782-
y_min, y_max = ylims
783-
y_range = y_max - y_min
784-
bar_y_start = y_min + 0.1 * y_range
785-
bar_y_end = y_max - 0.1 * y_range
786-
bar_height = bar_y_end - bar_y_start
787-
788-
if max_cost > 0
789-
filled_height = (current_cost / max_cost) * bar_height
790-
else
791-
filled_height = 0.0
792-
end
793-
794-
plot!(
795-
fig,
796-
[bar_x_start, bar_x_end, bar_x_end, bar_x_start, bar_x_start],
797-
[bar_y_start, bar_y_start, bar_y_end, bar_y_end, bar_y_start];
798-
seriestype=:shape,
799-
color=:white,
800-
alpha=0.8,
801-
linecolor=:black,
802-
linewidth=2,
803-
label="",
804-
)
805-
806-
if filled_height > 0
807-
cmap = Plots.cgrad(cost_bar_color_palette)
808-
ratio = max_cost > 0 ? current_cost / max_cost : 0.0
809-
color_at_val = Plots.get(cmap, ratio)
810-
plot!(
811-
fig,
812-
[bar_x_start, bar_x_end, bar_x_end, bar_x_start, bar_x_start],
813-
[
814-
bar_y_start,
815-
bar_y_start,
816-
bar_y_start + filled_height,
817-
bar_y_start + filled_height,
818-
bar_y_start,
819-
];
820-
seriestype=:shape,
821-
color=color_at_val,
822-
alpha=0.7,
823-
linecolor=:darkred,
824-
linewidth=1,
825-
label="",
826-
)
827-
end
828-
829-
cost_text_y = bar_y_start + filled_height + 0.02 * y_range
830-
if cost_text_y > bar_y_end
831-
cost_text_y = bar_y_end
832-
end
833-
plot!(
834-
fig,
835-
[bar_x_start + (bar_x_end - bar_x_start) / 2],
836-
[cost_text_y];
837-
seriestype=:scatter,
838-
markersize=0,
839-
label="",
840-
annotations=(
841-
bar_x_start - 0.04 * x_range,
842-
cost_text_y,
843-
(@sprintf("%.1f", current_cost), :center, guidefontsize),
844-
),
845-
)
846-
end
847-
848-
push!(col_plots, fig)
849-
end
850-
851-
combined = plot(
852-
col_plots...; layout=(1, n_solutions), size=figsize, link=:both, clims=clims
853-
)
854-
combined
855-
end
856-
857-
gif(anim, filename; fps=fps)
858-
return anim
859-
end

test/dynamic_vsp_plots.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
env = environments[1]
1111

1212
# Test basic plotting functions
13-
fig1 = DVSP.plot_instancee(env)
13+
fig1 = DVSP.plot_instance(env)
1414
@test fig1 isa Plots.Plot
1515

1616
instance = dataset[1].instance
@@ -23,7 +23,7 @@
2323
policies = generate_policies(b)
2424
lazy = policies[1]
2525
_, d = evaluate_policy!(lazy, env)
26-
fig3 = DVSP.plot_routes(d[1].instance, d[1].y_true)
26+
fig3 = DVSP.plot_routes(d[1].instance.state, d[1].y_true)
2727
@test fig3 isa Plots.Plot
2828

2929
# Test animation

0 commit comments

Comments
 (0)