3 changes: 2 additions & 1 deletion src/forge/controller/provisioner.py
@@ -262,7 +262,8 @@ def bootstrap(env: dict[str, str]):

self._proc_host_map[procs] = host_mesh

# Spawn local logging actor on each process and register with global logger
# Spawn local fetcher actor on each process and register with global logger
# Can be disabled by FORGE_DISABLE_METRICS env var
_ = await get_or_create_metric_logger(procs)
return procs

1 change: 1 addition & 0 deletions src/forge/env_constants.py
@@ -14,4 +14,5 @@
METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA"

# Makes forge.observability.metrics.record_metric a no-op
# and disables spawning LocalFetcherActor in get_or_create_metric_logger
FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS"
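
For reference, a minimal sketch of how a caller could opt out of metrics with this flag. Setting it before provisioning is an assumption, but the case-insensitive "true" comparison matches the check added in metric_actors.py below:

```python
# Hypothetical opt-out sketch: with the flag set, record_metric becomes a no-op
# and get_or_create_metric_logger skips spawning a LocalFetcherActor.
import os

from forge.env_constants import FORGE_DISABLE_METRICS

os.environ[FORGE_DISABLE_METRICS] = "true"  # compared case-insensitively against "true"
```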
9 changes: 7 additions & 2 deletions src/forge/observability/metric_actors.py
@@ -6,10 +6,12 @@

import asyncio
import logging
import os
from typing import Any, Dict, Optional

from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc

from forge.env_constants import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
get_logger_backend_class,
LoggerBackend,
@@ -95,8 +97,11 @@ async def get_or_create_metric_logger(
f"Both should be True (already setup) or both False (needs setup)."
)

# Setup local_fetcher_actor if needed
if not proc_has_local_fetcher:
# Setup local_fetcher_actor if needed (unless disabled by environment flag)
if (
not proc_has_local_fetcher
and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true"
):
local_fetcher_actor = proc.spawn(
"local_fetcher_actor", LocalFetcherActor, global_logger
)
37 changes: 34 additions & 3 deletions src/forge/observability/metrics.py
@@ -437,7 +437,21 @@ async def init_backends(

def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None:
if not self._is_initialized:
raise ValueError("Collector not initialized—call init first")
from forge.util.logging import log_once

log_once(
logger,
level=logging.WARNING,
msg=(
"Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized."
" This happens when you try to use `record_metric` before calling `init_backends`."
" To disable this warning, please call in your main file:\n"
"`mlogger = await get_or_create_metric_logger()`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
"or set env variable `FORGE_DISABLE_METRICS=True`"
),
)
return

if key not in self.accumulators:
self.accumulators[key] = reduction.accumulator_class(reduction)
@@ -458,8 +472,16 @@ async def flush(
e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}.
"""
if not self._is_initialized:
logger.debug(
f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first."
from forge.util.logging import log_once

log_once(
logger,
level=logging.WARNING,
msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()."
"\nPlease call in your main file:\n"
"`mlogger = await get_or_create_metric_logger()`\n"
"`await mlogger.init_backends.call_one(logging_config)`\n"
"before calling `flush`",
)
return {}
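
Both warnings above point to the same fix. A minimal sketch of that initialization order, assuming an async entry point and a valid `logging_config` dict; the `record_metric` signature and the `Reduce` import are assumptions, while the two logger calls come straight from the warning messages:

```python
# Sketch only: initialize backends once before recording or flushing metrics.
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce


async def main(logging_config: dict) -> None:
    mlogger = await get_or_create_metric_logger()
    await mlogger.init_backends.call_one(logging_config)

    # Safe only after init_backends; otherwise push()/flush() emit the warnings above.
    record_metric("loss", 1.2, Reduce.MEAN)
```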

@@ -662,6 +684,15 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]):
raise ValueError(
f"Shared ID required but not provided for {self.name} backend init"
)

# Clear any stale service tokens that might be pointing to dead processes
# In multiprocessing environments, WandB service tokens can become stale and point
# to dead service processes. This causes wandb.init() to hang indefinitely trying
# to connect to non-existent services. Clearing forces fresh service connection.
from wandb.sdk.lib.service import service_token
Review comment (Contributor): should this also be imported at the top?

Reply from @felipemello1 (Contributor, Author), Oct 8, 2025: I don't think so. My opinion is that backends like wandb should be protected; otherwise the user is required to have wandb installed even if they don't use it. Extrapolate that to mlflow, scuba, etc. Let me know if you disagree.

service_token.clear_service_in_env()

settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name)
self.run = wandb.init(
id=shared_id,
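
Regarding the review thread above: a minimal sketch (not the repository's code) of the deferred-import pattern being defended, which keeps wandb an optional dependency for users who never select that backend:

```python
# Illustrative only: the class name and signature are invented for this sketch.
class LazyWandbBackendSketch:
    async def init(self, role: str = "local") -> None:
        try:
            import wandb  # imported only when this backend is actually initialized
        except ImportError as err:
            raise RuntimeError(
                "The wandb backend was selected, but wandb is not installed."
            ) from err
        self._wandb = wandb
```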
95 changes: 95 additions & 0 deletions tests/unit_tests/observability/conftest.py
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Shared fixtures and mocks for observability unit tests."""

from unittest.mock import MagicMock, patch

import pytest
from forge.observability.metrics import LoggerBackend, MetricCollector


class MockBackend(LoggerBackend):
"""Mock backend for testing metrics logging without external dependencies."""

def __init__(self, logger_backend_config=None):
super().__init__(logger_backend_config or {})
self.logged_metrics = []
self.init_called = False
self.finish_called = False
self.metadata = {}

async def init(self, role="local", primary_logger_metadata=None):
self.init_called = True
self.role = role
self.primary_logger_metadata = primary_logger_metadata or {}

async def log(self, metrics, step):
self.logged_metrics.append((metrics, step))

async def finish(self):
self.finish_called = True

def get_metadata_for_secondary_ranks(self):
return self.metadata


@pytest.fixture(autouse=True)
def clear_metric_collector_singletons():
"""Clear MetricCollector singletons before each test to avoid state leakage."""
MetricCollector._instances.clear()
yield
MetricCollector._instances.clear()


@pytest.fixture(autouse=True)
def clean_metrics_environment():
"""Override the global mock_metrics_globally fixture to allow real metrics testing."""
import os

from forge.env_constants import FORGE_DISABLE_METRICS

# Set default state for tests (metrics enabled)
if FORGE_DISABLE_METRICS in os.environ:
del os.environ[FORGE_DISABLE_METRICS]

yield


@pytest.fixture
def mock_rank():
"""Mock current_rank function with configurable rank."""
with patch("forge.observability.metrics.current_rank") as mock:
rank_obj = MagicMock()
rank_obj.rank = 0
mock.return_value = rank_obj
yield mock


@pytest.fixture
def mock_actor_context():
"""Mock Monarch actor context for testing actor name generation."""
with patch("forge.observability.metrics.context") as mock_context, patch(
"forge.observability.metrics.current_rank"
) as mock_rank:

# Setup mock context
ctx = MagicMock()
actor_instance = MagicMock()
actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]"
ctx.actor_instance = actor_instance
mock_context.return_value = ctx

# Setup mock rank
rank_obj = MagicMock()
rank_obj.rank = 0
mock_rank.return_value = rank_obj

yield {
"context": mock_context,
"rank": mock_rank,
"expected_name": "TestActor_0XQr_r0",
}
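
As a usage illustration, a test could exercise `MockBackend` directly; the sketch below is not part of the PR, and the `asyncio.run`-based style is an assumption:

```python
# Hypothetical test showing how MockBackend records init/log/finish calls.
import asyncio


def test_mock_backend_records_calls():
    backend = MockBackend()
    asyncio.run(backend.init(role="local"))
    asyncio.run(backend.log({"loss": 0.5}, step=1))
    asyncio.run(backend.finish())

    assert backend.init_called and backend.finish_called
    assert backend.logged_metrics == [({"loss": 0.5}, 1)]
```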