@@ -396,6 +396,7 @@ function animate_epochs(
396396 show_colorbar= false ,
397397 cost_bar_width= 0.05 ,
398398 cost_bar_margin= 0.02 ,
399+ cost_bar_color_palette= :turbo ,
399400 kwargs... ,
400401)
401402 n_epochs = length (data_samples)
@@ -604,7 +605,7 @@ function animate_epochs(
604605 label= " " ,
605606 )
606607
607- cmap = Plots. cgrad (:turbo ) # or :plasma, :inferno, etc.
608+ cmap = Plots. cgrad (cost_bar_color_palette)
608609 # Draw the filled portion with solid color
609610 if filled_height > 0
610611 # Get a color at a value between 0 and 1
@@ -629,39 +630,10 @@ function animate_epochs(
629630 )
630631 end
631632
632- # Add cost labels
633- # plot!(
634- # fig,
635- # [bar_x_start - 0.01 * x_range],
636- # [bar_y_start];
637- # seriestype=:scatter,
638- # markersize=0,
639- # label="",
640- # annotations=(
641- # bar_x_start - 0.02 * x_range, bar_y_start, ("0", :right, guidefontsize)
642- # ),
643- # )
644-
645- # if max_cost > 0
646- # plot!(
647- # fig,
648- # [bar_x_start - 0.01 * x_range],
649- # [bar_y_end];
650- # seriestype=:scatter,
651- # markersize=0,
652- # label="",
653- # annotations=(
654- # bar_x_start - 0.02 * x_range,
655- # bar_y_end,
656- # (@sprintf("%.1f", max_cost), :right, guidefontsize),
657- # ),
658- # )
659- # end
660-
661633 # Add current cost value
662634 cost_text_y = bar_y_start + filled_height + 0.02 * y_range
663635 if cost_text_y > bar_y_end
664- cost_text_y = bar_y_end # + 0.01 * y_range
636+ cost_text_y = bar_y_end
665637 end
666638
667639 plot! (
@@ -672,7 +644,7 @@ function animate_epochs(
672644 markersize= 0 ,
673645 label= " " ,
674646 annotations= (
675- bar_x_start - 0.04 * x_range,# (bar_x_start + bar_x_end) / 2,
647+ bar_x_start - 0.04 * x_range,
676648 cost_text_y,
677649 (@sprintf (" %.1f" , current_cost), :center , guidefontsize),
678650 ),
0 commit comments