From fffcb88e97bbd38ac4d7c713026190593f6b1a01 Mon Sep 17 00:00:00 2001 From: Allen Wang <9057208+allenwang28@users.noreply.github.com> Date: Wed, 15 Oct 2025 15:42:22 -0700 Subject: [PATCH] Revert "Metric Logging updates 4/N - better actor name (#351)" This reverts commit 1f45470b0294c50cee539fd110722173c6714872. --- apps/grpo/main.py | 2 +- src/forge/controller/provisioner.py | 2 +- src/forge/observability/__init__.py | 6 +- src/forge/observability/metric_actors.py | 50 +----- src/forge/observability/metrics.py | 99 +++++++---- src/forge/observability/utils.py | 54 ------ tests/sandbox/toy_rl/toy_metrics/main.py | 10 +- tests/sandbox/vllm/main.py | 2 +- tests/unit_tests/observability/__init__.py | 5 - tests/unit_tests/observability/conftest.py | 7 +- .../observability/test_metric_actors.py | 162 ------------------ .../unit_tests/observability/test_metrics.py | 20 ++- tests/unit_tests/observability/test_utils.py | 60 ------- 13 files changed, 103 insertions(+), 376 deletions(-) delete mode 100644 src/forge/observability/utils.py delete mode 100644 tests/unit_tests/observability/__init__.py delete mode 100644 tests/unit_tests/observability/test_metric_actors.py delete mode 100644 tests/unit_tests/observability/test_utils.py diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 5a6576d7e..1dbef0b76 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -305,7 +305,7 @@ async def main(cfg: DictConfig): provisioner = await init_provisioner() metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger(process_name="Controller") + mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(metric_logging_cfg) # ---- Setup services ---- # diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 6d470a87f..c23d5fdd0 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -305,7 +305,7 @@ def bootstrap(env: dict[str, str]): if not FORGE_DISABLE_METRICS.get_value(): from forge.observability.metric_actors import get_or_create_metric_logger - _ = await get_or_create_metric_logger(procs, process_name=mesh_name) + _ = await get_or_create_metric_logger(procs) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 8efd3dace..b970e57fa 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,6 +12,8 @@ from .metrics import ( BackendRole, ConsoleBackend, + get_actor_name_with_rank, + get_logger_backend_class, LoggerBackend, MaxAccumulator, MeanAccumulator, @@ -27,12 +29,12 @@ WandbBackend, ) from .perf_tracker import trace, Tracer -from .utils import get_proc_name_with_rank __all__ = [ # Main API functions "record_metric", "reduce_metrics_states", + "get_actor_name_with_rank", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking @@ -43,8 +45,6 @@ "BackendRole", # Enums "Reduce", - # Utility functions - "get_proc_name_with_rank", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index f053d6a56..83ddd349e 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -8,14 +8,7 @@ import logging from typing import Any, Union -from monarch.actor import ( - Actor, - context, - endpoint, - get_or_spawn_controller, - ProcMesh, - this_proc, -) +from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc from forge.env import FORGE_DISABLE_METRICS from forge.observability.metrics import ( @@ -34,7 +27,6 @@ async def get_or_create_metric_logger( proc_mesh: ProcMesh | None = None, - process_name: str | None = None, ) -> "GlobalLoggingActor": """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), if not already initialized, registers it with the GlobalLoggingActor and returns the @@ -48,9 +40,6 @@ async def get_or_create_metric_logger( Args: proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `monarch.actor.this_proc()`. - process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging. - If None, will be auto-detected from the mesh_name provided during actor initialization or - a generic mesh name if one was not provided. Returns: GlobalLoggingActor: The global logging controller. @@ -64,7 +53,7 @@ async def get_or_create_metric_logger( from forge.observability.metrics import record_metric # Main process setup - mlogger = await get_or_create_metric_logger(process_name="Controller") + mlogger = await get_or_create_metric_logger() # Initialize logging backends await mlogger.init_backends({ @@ -77,14 +66,13 @@ async def get_or_create_metric_logger( # Training loop for step in range(max_steps): - record_metric("loss", 1.2, reduction_type=Reduce.MEAN) + record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) # ... training code with record_metric() calls ... await mlogger.flush(step) # Log metrics for this step # Shutdown await mlogger.shutdown() """ - # Get or create the singleton global logger global _global_logger if _global_logger is None: @@ -96,11 +84,6 @@ async def get_or_create_metric_logger( # Determine process context proc = proc_mesh if proc_mesh is not None else this_proc() - # Auto-detect process_name from proc mesh if not provided - if process_name is None: - ctx = context() - process_name = ctx.actor_instance.actor_id.actor_name - # Check current state for consistency proc_has_local_fetcher = hasattr(proc, "_local_fetcher") global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) @@ -108,7 +91,7 @@ async def get_or_create_metric_logger( # Consistency check: both should be in sync if proc_has_local_fetcher != global_logger_has_local_fetcher: raise ValueError( - f"Inconsistent logging state for {proc=} with {process_name=}: " + f"Inconsistent logging state for proc {proc}: " f"proc has _local_fetcher={proc_has_local_fetcher}, " f"but global_logger has registration={global_logger_has_local_fetcher}. " f"This indicates a bug in logging setup/teardown. " @@ -118,7 +101,7 @@ async def get_or_create_metric_logger( # Setup local_fetcher_actor if needed (unless disabled by environment flag) if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value(): local_fetcher_actor = proc.spawn( - "local_fetcher_actor", LocalFetcherActor, global_logger, process_name + "local_fetcher_actor", LocalFetcherActor, global_logger ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) proc._local_fetcher = local_fetcher_actor # pyre-ignore @@ -134,13 +117,8 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ - def __init__( - self, - global_logger: Union["GlobalLoggingActor", None] = None, - process_name: str | None = None, - ) -> None: + def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: self.global_logger = global_logger - self.process_name = process_name # Passed to MetricCollector for logging _is_initialized = False @endpoint @@ -167,22 +145,10 @@ async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]], config: dict[str, Any], - global_step: int = 0, ) -> None: - """Init local (per-rank) logger backends and MetricCollector. - - Args: - metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. - config (dict[str, Any]): Backend configurations with logging modes and settings. - global_step (int): Initial step for metrics. - """ + """Init local (per-rank) logger backends and MetricCollector.""" collector = MetricCollector() - await collector.init_backends( - metadata_per_primary_backend, - config, - global_step, - process_name=self.process_name, - ) + await collector.init_backends(metadata_per_primary_backend, config) @endpoint async def shutdown(self) -> None: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4996b3a7f..3ce849ad2 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,9 +13,8 @@ from typing import Any import pytz -from monarch.actor import current_rank +from monarch.actor import context, current_rank -from forge.observability.utils import get_proc_name_with_rank from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -55,7 +54,7 @@ def accumulator_class(self): class Metric: """Container for metric data including key, value, reduction type, and timestamp. - Timestamp is automatically set to current UTC time if not provided. + Timestamp is automatically set to current EST time if not provided. """ key: str @@ -69,6 +68,55 @@ def __post_init__(self): self.timestamp = datetime.now(pytz.UTC).timestamp() +def get_actor_name_with_rank() -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. + """ + # Add more defensive checks + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + rank_name = "UnknownActor" # fallback + if len(parts) >= 2: + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Extract world ID and proc rank + world_id = world_part.split("[")[0] if "[" in world_part else world_part + + # Extract clean actor name (remove "Configured" suffix if present) + if "[" in actor_part: + actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" + if actor_name.endswith("Configured"): + actor_name = actor_name[:-10] # Remove "Configured" + else: + actor_name = actor_part + + # Use last 4 characters of world_id as replica identifier + # This is deterministic, readable, and works for any number of replicas + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + # Use current_rank().rank as the local rank within the replica + local_rank = rank.rank + + rank_name = f"{actor_name}_{replica_id}_r{local_rank}" + + return rank_name + + def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: """Thin wrapper to send metrics to per-rank local MetricCollectors. @@ -102,11 +150,11 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri states is more precise than merging locally reduced metrics. Args: - states (list[dict[str, dict[str, Any]]]): list of state of one or more metrics, + states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. Returns: - list[Metric]: list of reduced metrics + list[Metric]: List of reduced metrics Example: states = [ @@ -390,14 +438,11 @@ def __init__(self) -> None: self.rank = current_rank().rank self.logger_backends: list[LoggerBackend] = [] self._is_initialized = False - self.process_name: str | None = None async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], - global_step: int = 0, - process_name: str | None = None, ) -> None: """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated @@ -407,16 +452,10 @@ async def init_backends( metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. - global_step (int, default 0): Initial step for metrics. - process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: - logger.debug( - f"{get_proc_name_with_rank(self.process_name)}: MetricCollector already initialized" - ) + logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return - self.process_name = process_name - self.global_step = global_step # instantiate local backends if any for backend_name, backend_config in config.items(): @@ -431,9 +470,7 @@ async def init_backends( # instantiate local backend logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role=BackendRole.LOCAL, - primary_logger_metadata=primary_metadata, - process_name=process_name, + role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata ) self.logger_backends.append(logger_backend) @@ -461,7 +498,7 @@ def push(self, metric: Metric) -> None: "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." " To disable this warning, please call in your main file:\n" - "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`mlogger = await get_or_create_metric_logger()`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "or set env variable `FORGE_DISABLE_METRICS=True`" ), @@ -490,7 +527,7 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - dict[str, dict[str, dict[str, Any]]]: dict of {metric_key: metric_state}, + dict[str, dict[str, dict[str, Any]]]: Dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: @@ -507,7 +544,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector for {get_proc_name_with_rank(self.process_name)}: No metrics to flush for global_step {global_step}" + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" ) return {} @@ -532,7 +569,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for rank {get_proc_name_with_rank(self.process_name)} not initialized. Skipping shutdown" + f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" ) return @@ -556,7 +593,6 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, - process_name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -566,7 +602,6 @@ async def init( Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. - process_name (str | None): Process name for logging. Raises: ValueError if missing metadata for shared local init. """ @@ -578,7 +613,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: Log a batch of metrics to the backend. Args: - metrics: list of Metric objects to log. + metrics: List of Metric objects to log. global_step: Step number for x-axis alignment across metrics. """ pass @@ -601,9 +636,12 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, - process_name: str | None = None, ) -> None: - self.prefix = get_proc_name_with_rank(proc_name=process_name) + self.prefix = ( + get_actor_name_with_rank() + if self.logger_backend_config.get("reduce_across_ranks", True) + else "Controller" + ) async def log(self, metrics: list[Metric], global_step: int) -> None: metrics_str = "\n".join( @@ -651,13 +689,16 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, - process_name: str | None = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = get_proc_name_with_rank(proc_name=process_name) + self.name = ( + get_actor_name_with_rank() + if role == BackendRole.LOCAL + else "global_controller" + ) # Default global mode: only inits on controller if self.reduce_across_ranks: diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py deleted file mode 100644 index 4a45274e3..000000000 --- a/src/forge/observability/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from typing import Optional - -from monarch.actor import context - -logger = logging.getLogger(__name__) - - -def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str: - """ - Returns a unique process identifier from Monarch actor context. - - Format: "ActorName_wxyz_r{rank}" where: - - ActorName: The actor class name (e.g., "TrainActor") - - wxyz: Last 4 chars of world_name (unique replica hash) - - rank: Local rank within the replica (0, 1, 2, ...) - - Note: If called from a direct proccess, defaults to "client_DPROC_r0". - - Args: - proc_name: Optional override for actor name. If None, uses actor_id.actor_name. - - Returns: - str: Unique identifier or fallback name if no context available. - """ - ctx = context() - actor_id = ctx.actor_instance.actor_id - - # Use actor_name from actor_id if not provided - if proc_name is None: - proc_name = actor_id.actor_name - - # Try to get world_name. Each replica has a unique value. - try: - world_name = actor_id.world_name - replica_id = world_name[-4:] if len(world_name) >= 4 else world_name - except BaseException: # Catches pyo3_runtime.PanicException from Rust - # Direct proc (e.g., client) - no world_name available - replica_id = "DPROC" - - # Get rank within the replica. NOT a global rank. - try: - rank = actor_id.rank - except BaseException: # Catches pyo3_runtime.PanicException from Rust - # Direct proc - no rank available - rank = 0 - - return f"{proc_name}_{replica_id}_r{rank}" diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index eae50c2db..57ccd97b5 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -95,16 +95,12 @@ async def main(): } service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} - mlogger = await get_or_create_metric_logger(process_name="Controller") + mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(config) # Spawn services first (triggers registrations via provisioner hook) - trainer = await TrainActor.options( - **service_config, mesh_name="TrainActor" - ).as_service() - generator = await GeneratorActor.options( - **service_config, mesh_name="GeneratorActor" - ).as_service() + trainer = await TrainActor.options(**service_config).as_service() + generator = await GeneratorActor.options(**service_config).as_service() for i in range(3): print(f"\n=== Global Step {i} ===") diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 7e0b22890..0d4652a6b 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -33,7 +33,7 @@ async def run(cfg: DictConfig): ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger(process_name="Controller") + mlogger = await get_or_create_metric_logger() await mlogger.init_backends.call_one(metric_logging_cfg) if (prompt := cfg.get("prompt")) is None: diff --git a/tests/unit_tests/observability/__init__.py b/tests/unit_tests/observability/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/tests/unit_tests/observability/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py index e35350c11..e8900392c 100644 --- a/tests/unit_tests/observability/conftest.py +++ b/tests/unit_tests/observability/conftest.py @@ -22,14 +22,13 @@ def __init__(self, logger_backend_config=None): self.finish_called = False self.metadata = {} - async def init(self, role="local", primary_logger_metadata=None, process_name=None): + async def init(self, role="local", primary_logger_metadata=None): self.init_called = True self.role = role self.primary_logger_metadata = primary_logger_metadata or {} - self.process_name = process_name - async def log(self, metrics, global_step): - self.logged_metrics.append((metrics, global_step)) + async def log(self, metrics, step): + self.logged_metrics.append((metrics, step)) async def finish(self): self.finish_called = True diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py deleted file mode 100644 index 501e13afe..000000000 --- a/tests/unit_tests/observability/test_metric_actors.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Optimized unit tests for metric actors functionality.""" - -import pytest - -from forge.observability.metric_actors import ( - get_or_create_metric_logger, - GlobalLoggingActor, - LocalFetcherActor, -) -from monarch.actor import this_host - - -@pytest.fixture -def global_logger(): - """Create a GlobalLoggingActor for testing.""" - p = this_host().spawn_procs(per_host={"cpus": 1}) - return p.spawn("TestGlobalLogger", GlobalLoggingActor) - - -@pytest.fixture -def local_fetcher(global_logger): - """Create a LocalFetcherActor linked to global logger.""" - p = this_host().spawn_procs(per_host={"cpus": 1}) - return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger) - - -class TestBasicOperations: - """Test basic operations for actors.""" - - @pytest.mark.asyncio - async def test_local_fetcher_flush(self, local_fetcher): - """Test LocalFetcherActor flush operations.""" - result_with_state = await local_fetcher.flush.call_one( - global_step=1, return_state=True - ) - assert result_with_state == {} - - result_without_state = await local_fetcher.flush.call_one( - global_step=1, return_state=False - ) - assert result_without_state == {} - - @pytest.mark.asyncio - async def test_global_logger_basic_ops(self, global_logger): - """Test GlobalLoggingActor basic operations.""" - count = await global_logger.get_fetcher_count.call_one() - assert count >= 0 - - has_fetcher = await global_logger.has_fetcher.call_one("nonexistent") - assert has_fetcher is False - - # Global logger flush (should not raise error) - await global_logger.flush.call_one(global_step=1) - - @pytest.mark.asyncio - async def test_backend_init(self, local_fetcher): - """Test backend initialization and shutdown.""" - metadata = {"wandb": {"shared_run_id": "test123"}} - config = {"console": {"logging_mode": "per_rank_reduce"}} - - await local_fetcher.init_backends.call_one(metadata, config, global_step=5) - await local_fetcher.shutdown.call_one() - - -class TestRegistrationLifecycle: - """Test registration lifecycle.""" - - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_registration_lifecycle(self, global_logger, local_fetcher): - """Test complete registration/deregistration lifecycle.""" - proc_name = "lifecycle_test_proc" - - # Initial state - initial_count = await global_logger.get_fetcher_count.call_one() - assert await global_logger.has_fetcher.call_one(proc_name) is False - - # Register - await global_logger.register_fetcher.call_one(local_fetcher, proc_name) - - # Verify registered - new_count = await global_logger.get_fetcher_count.call_one() - assert new_count == initial_count + 1 - assert await global_logger.has_fetcher.call_one(proc_name) is True - - # Deregister - await global_logger.deregister_fetcher.call_one(proc_name) - - # Verify deregistered - final_count = await global_logger.get_fetcher_count.call_one() - assert final_count == initial_count - assert await global_logger.has_fetcher.call_one(proc_name) is False - - -class TestBackendConfiguration: - """Test backend configuration validation.""" - - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_valid_backend_configs(self, global_logger): - """Test valid backend configurations.""" - # Empty config - await global_logger.init_backends.call_one({}) - - # Valid configs for all logging modes - for mode in ["per_rank_reduce", "per_rank_no_reduce", "global_reduce"]: - config = {"console": {"logging_mode": mode}} - await global_logger.init_backends.call_one(config) - - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_invalid_backend_configs(self, global_logger): - """Test invalid backend configurations are handled gracefully.""" - # Empty config should work - await global_logger.init_backends.call_one({}) - - # Config with only project should work - config_with_project = {"console": {"project": "test_project"}} - await global_logger.init_backends.call_one(config_with_project) - - # Config with reduce_across_ranks should work (Diff 3 doesn't validate logging_mode yet) - config_with_reduce = {"console": {"reduce_across_ranks": True}} - await global_logger.init_backends.call_one(config_with_reduce) - - -class TestErrorHandling: - """Test error handling scenarios.""" - - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_deregister_nonexistent_fetcher(self, global_logger): - """Test deregistering non-existent fetcher doesn't crash.""" - await global_logger.deregister_fetcher.call_one("nonexistent_proc") - - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_shutdown(self, global_logger): - """Test shutdown without issues.""" - await global_logger.shutdown.call_one() - - -class TestGetOrCreateMetricLogger: - """Test the integration function.""" - - @pytest.mark.timeout(3) - @pytest.mark.asyncio - async def test_get_or_create_functionality(self): - """Test get_or_create_metric_logger basic functionality.""" - result = await get_or_create_metric_logger(process_name="TestController") - - # Should return a GlobalLoggingActor mesh - assert result is not None - - # Should be able to call basic methods - count = await result.get_fetcher_count.call_one() - assert count >= 0 diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index d0f104459..701bda2dc 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -80,9 +80,12 @@ def test_new_enums_and_constants(self): assert isinstance(BackendRole.LOCAL, BackendRole) assert isinstance(BackendRole.GLOBAL, BackendRole) + @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_backend_role_usage(self): + async def test_backend_role_usage(self, mock_actor_name): """Test that BackendRole constants are actually used instead of string literals.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + # Test ConsoleBackend console_backend = ConsoleBackend({}) await console_backend.init(role=BackendRole.LOCAL) @@ -292,8 +295,10 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): mock_collector_class.assert_called_once() mock_collector.push.assert_called_once() - def test_wandb_backend_creation(self): + @patch("forge.observability.metrics.get_actor_name_with_rank") + def test_wandb_backend_creation(self, mock_actor_name): """Test WandbBackend creation and basic setup without WandB dependency.""" + mock_actor_name.return_value = "TestActor_abcd_r0" config = { "project": "test_project", @@ -311,9 +316,12 @@ def test_wandb_backend_creation(self): metadata = backend.get_metadata_for_secondary_ranks() assert metadata == {} # Should be empty when no run + @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_console_backend(self): + async def test_console_backend(self, mock_actor_name): """Test ConsoleBackend basic operations.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + backend = ConsoleBackend({}) await backend.init(role=BackendRole.LOCAL) @@ -417,10 +425,8 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetche if hasattr(procs, "_local_fetcher"): delattr(procs, "_local_fetcher") - # Test functionality - pass explicit process_name since test bypasses provisioner - global_logger = await get_or_create_metric_logger( - proc_mesh=procs, process_name="TestProcess" - ) + # Test functionality + global_logger = await get_or_create_metric_logger(proc_mesh=procs) # Get results to check proc_has_fetcher = hasattr(procs, "_local_fetcher") diff --git a/tests/unit_tests/observability/test_utils.py b/tests/unit_tests/observability/test_utils.py deleted file mode 100644 index 9a4e24d0c..000000000 --- a/tests/unit_tests/observability/test_utils.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Tests for observability utility functions.""" - -import pytest - -from forge.observability.utils import get_proc_name_with_rank -from monarch.actor import Actor, endpoint, this_host - - -class UtilActor(Actor): - """Actor for testing get_proc_name_with_rank in spawned context.""" - - @endpoint - async def get_name(self) -> str: - return get_proc_name_with_rank() - - @endpoint - async def get_name_with_override(self, name: str) -> str: - return get_proc_name_with_rank(proc_name=name) - - -class TestGetProcNameWithRank: - """Tests for get_proc_name_with_rank utility.""" - - def test_direct_proc(self): - """Direct proc (test process) should return client_DPROC_r0.""" - result = get_proc_name_with_rank() - assert result == "client_DPROC_r0" - - def test_direct_proc_with_override(self): - """Direct proc with override should use provided name.""" - result = get_proc_name_with_rank(proc_name="MyProcess") - assert result == "MyProcess_DPROC_r0" - - @pytest.mark.timeout(10) - @pytest.mark.asyncio - async def test_spawned_actor(self): - """Spawned actor should return ActorName_replica_rank format.""" - p = this_host().spawn_procs(per_host={"cpus": 2}) - actor = p.spawn("UtilActor", UtilActor) - - # no override - results = await actor.get_name.call() - - assert len(results) == 2 - for i, (rank_info, result) in enumerate(results): - replica_id = result.split("_")[1] - assert result == f"UtilActor_{replica_id}_r{i}" - - # override name - results = await actor.get_name_with_override.call("CustomName") - - for i, (rank_info, result) in enumerate(results): - replica_id = result.split("_")[1] - assert result == f"CustomName_{replica_id}_r{i}"