@@ -101,6 +101,9 @@ def generate_ckpt(
101101 pre_checkpoint_batches = batches [:steps_before_checkpoint ]
102102 post_checkpoint_batches = batches [steps_before_checkpoint :]
103103
104+ # Compute metrics for post-checkpoint batches only
105+ post_checkpoint_metrics = all_metrics [len (checkpoint_metrics ) :]
106+
104107 # Resume with new instance if provided
105108 resumed_batches = []
106109 resumed_metrics = []
@@ -127,24 +130,28 @@ def generate_ckpt(
127130 # Original run
128131 "pre_checkpoint_batches" : pre_checkpoint_batches ,
129132 "post_checkpoint_batches" : post_checkpoint_batches ,
130- "metrics_at_checkpoint" : keep_last_metric (checkpoint_metrics ),
131- "final_metrics" : keep_last_metric (all_metrics ),
133+ "metrics_at_checkpoint" : aggregate_metrics (checkpoint_metrics ),
134+ "post_checkpoint_metrics" : aggregate_metrics (post_checkpoint_metrics ),
135+ "final_metrics" : aggregate_metrics (all_metrics ),
132136 # Resumed run
133137 "resumed_batches" : resumed_batches ,
134- "resumed_metrics" : keep_last_metric (resumed_metrics ),
138+ "resumed_metrics" : aggregate_metrics (resumed_metrics ),
135139 # Internal state for loading - only if someone needs to manually load
136140 "_checkpoint_state" : checkpoint_state ,
137141 }
138142
139143
140- def keep_last_metric (metrics_list : list ) -> dict [str , Any ]:
141- result = {}
144+ def aggregate_metrics (metrics_list : list ) -> dict [str , Any ]:
145+ """Aggregate metrics according to their reduction types (SUM, MEAN, MAX, MIN, STD)."""
146+ if not metrics_list :
147+ return {}
148+
149+ accumulators = {}
150+
142151 for metric in metrics_list :
143- # Expect observability.Metric objects only
144152 key = metric .key
145- value = metric .value
146-
147- # For test purposes, just keep the last value of each metric
148- result [key ] = value
153+ if key not in accumulators :
154+ accumulators [key ] = metric .reduction .accumulator_class (metric .reduction )
155+ accumulators [key ].append (metric .value )
149156
150- return result
157+ return { key : acc . get_value () for key , acc in accumulators . items ()}
0 commit comments