
Commit e7e42b9

fix error + some debug messages
1 parent 9234389 commit e7e42b9

File tree

2 files changed: +26 -7 lines changed

  src/forge/observability/metric_actors.py
  src/forge/observability/metrics.py

src/forge/observability/metric_actors.py

Lines changed: 3 additions & 0 deletions

@@ -18,6 +18,7 @@
     LoggerBackend,
     LoggingMode,
     MetricCollector,
+    Reduce,
     reduce_metrics_states,
 )

@@ -432,6 +433,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Reduce metrics from states
         reduced_metrics = reduce_metrics_states(all_local_states)

+        print(f"[DEBUG] reduced_metrics: {reduced_metrics}")
         # Split into scalar metrics and sample metrics
         scalar_metrics = [
             m for m in reduced_metrics if m.reduction != Reduce.SAMPLE

@@ -443,6 +445,7 @@ def extract_values_from_valuemesh(results) -> list[dict[str, Any]]:
         # Log to global backends
         for backend_name, backend in self.global_logger_backends.items():
             if scalar_metrics:
+                print(f"[DEBUG] calling log_batch from GlobalLoggerActor")
                 await backend.log_batch(scalar_metrics, global_step)
             if sample_metrics:
                 await backend.log_samples(sample_metrics, global_step)
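
For context, the split these hunks feed (and that the new Reduce import enables) looks roughly like the sketch below. The Metric and Reduce stand-ins are assumptions modeled on the identifiers visible in the diff (metric.key, metric.value, metric.reduction, Reduce.SAMPLE), not the real forge.observability types:

# Sketch only: stand-in types modeled on the names used in the diff;
# the real classes live in forge.observability.metrics.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any


class Reduce(Enum):
    MEAN = auto()
    SAMPLE = auto()


@dataclass
class Metric:
    key: str
    value: Any
    reduction: Reduce


def split_metrics(reduced: list[Metric]) -> tuple[list[Metric], list[Metric]]:
    """Split reduced metrics into scalars (log_batch) and samples (log_samples)."""
    scalars = [m for m in reduced if m.reduction != Reduce.SAMPLE]
    samples = [m for m in reduced if m.reduction == Reduce.SAMPLE]
    return scalars, samples


metrics = [
    Metric("loss", 0.42, Reduce.MEAN),
    Metric("rollout", {"episode_id": 7}, Reduce.SAMPLE),
]
scalars, samples = split_metrics(metrics)
assert [m.key for m in scalars] == ["loss"]
assert [m.key for m in samples] == ["rollout"]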

src/forge/observability/metrics.py

Lines changed: 23 additions & 7 deletions

@@ -228,7 +228,15 @@ def record_episode_sample(table_name: str, episode):
         "pad_id": episode.pad_id,
     }

+    print(
+        "[DEBUG] Adding sample to table via record_metric, episode_id: ",
+        episode.episode_id,
+    )
     record_metric(table_name, sample, Reduce.SAMPLE)
+    print(
+        "[DEBUG] Added sample to table via record_metric, episode_id: ",
+        episode.episode_id,
+    )


 #################
@@ -499,19 +507,22 @@ def __init__(self, reduction: Reduce):
         super().__init__(reduction)
         self.samples: List[Dict[str, Any]] = []
         self.filter = TopBottomKFilter()
+        self.is_reset = True

     def append(self, value: dict) -> None:
         if not isinstance(value, dict):
             raise ValueError(f"Expected dict, got {type(value)}")

+        self.is_reset = False
         # Only keep the sample if filter_append returns True
         if self.filter.filter_append(value):
             self.samples.append(value)

     def get_value(self) -> list[dict]:
         """Return locally collected (and optionally filtered) samples."""
         # Apply flush-time filter (e.g. heap selection, threshold trimming)
-        return self.filter.filter_flush(self.samples)
+        results = self.filter.filter_flush(self.samples)
+        return results

     def get_state(self) -> Dict[str, Any]:
         """Serialize accumulator state for cross-rank reduction."""
@@ -530,6 +541,7 @@ def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dic

     def reset(self) -> None:
         """Clear local samples and reset filter state."""
+        self.is_reset = True
         self.samples.clear()
         self.filter.reset()
@@ -701,12 +713,12 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            if metric.reduction == Reduce.SAMPLE:
-                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
-                sample = {metric.key: [metric.value]}
-                asyncio.create_task(backend.log_samples(sample, self.global_step))
-            else:
-                backend.log_stream(metric=metric, global_step=self.global_step)
+            # if metric.reduction == Reduce.SAMPLE:
+            #     # Wrap singleton Metric into expected {key: [list_of_dicts]} format
+            #     sample = {metric.key: [metric.value]}
+            #     asyncio.create_task(backend.log_samples(sample, self.global_step))
+            # else:
+            backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key
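
Net effect of the commented-out branch: sample metrics are no longer wrapped as {key: [value]} and fanned out through log_samples; every metric, Reduce.SAMPLE included, now reaches streaming backends through log_stream. A hypothetical backend sketch, assuming only the call shape shown in the diff (the class itself and its printing are made up):

from types import SimpleNamespace


class StdoutStreamBackend:
    """Hypothetical PER_RANK_NO_REDUCE backend; only the log_stream call shape is from the diff."""

    def log_stream(self, metric, global_step: int) -> None:
        # After this change metric.value may be a raw sample dict rather than
        # a scalar, so a streaming backend has to tolerate both.
        print(f"step={global_step} {metric.key}={metric.value!r}")


backend = StdoutStreamBackend()
backend.log_stream(SimpleNamespace(key="loss", value=0.42), global_step=10)
backend.log_stream(SimpleNamespace(key="rollout", value={"episode_id": 7}), global_step=10)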
@@ -773,6 +785,7 @@ async def flush(

         for backend in self.per_rank_reduce_backends:
             if scalar_metrics:
+                print(f"[DEBUG] calling log_batch from MetricCollector")
                 await backend.log_batch(scalar_metrics, global_step)
             if sample_metrics:
                 await backend.log_samples(sample_metrics, global_step)
@@ -895,6 +908,7 @@ async def init(
     async def log_batch(
         self, metrics: list[Metric], global_step: int, *args, **kwargs
     ) -> None:
+        print(f"[DEBUG] calling log_batch with {len(metrics)} metrics")
         metrics_str = "\n".join(
             f" {metric.key}: {metric.value}"
             for metric in sorted(metrics, key=lambda m: m.key)
@@ -913,6 +927,8 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         """Pretty-print sample-level logs to console."""
         import json

+        print(f"[DEBUG] calling log_samples with {len(samples)} samples")
+
         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
         for table_name, table_rows in samples.items():
             logger.info(f"[{table_name}] ({len(table_rows)} samples)")
