diff --git a/apps/grpo/main.py b/apps/grpo/main.py index ff46fea20..62d41d62c 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -323,7 +323,7 @@ async def main(cfg: DictConfig): provisioner = await init_provisioner() metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) # ---- Setup services ---- # diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 8f5a77f41..c83595c65 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -301,7 +301,7 @@ def bootstrap(env: dict[str, str]): if not FORGE_DISABLE_METRICS.get_value(): from forge.observability.metric_actors import get_or_create_metric_logger - _ = await get_or_create_metric_logger(procs) + _ = await get_or_create_metric_logger(procs, process_name=mesh_name) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index b970e57fa..8efd3dace 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,8 +12,6 @@ from .metrics import ( BackendRole, ConsoleBackend, - get_actor_name_with_rank, - get_logger_backend_class, LoggerBackend, MaxAccumulator, MeanAccumulator, @@ -29,12 +27,12 @@ WandbBackend, ) from .perf_tracker import trace, Tracer +from .utils import get_proc_name_with_rank __all__ = [ # Main API functions "record_metric", "reduce_metrics_states", - "get_actor_name_with_rank", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking @@ -45,6 +43,8 @@ "BackendRole", # Enums "Reduce", + # Utility functions + "get_proc_name_with_rank", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 2f9addfe6..23ff0a9cb 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -8,7 +8,14 @@ import logging from typing import Any, Union -from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc +from monarch.actor import ( + Actor, + context, + endpoint, + get_or_spawn_controller, + ProcMesh, + this_proc, +) from forge.env import FORGE_DISABLE_METRICS from forge.observability.metrics import ( @@ -27,6 +34,7 @@ async def get_or_create_metric_logger( proc_mesh: ProcMesh | None = None, + process_name: str | None = None, ) -> "GlobalLoggingActor": """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), if not already initialized, registers it with the GlobalLoggingActor and returns the @@ -40,6 +48,9 @@ async def get_or_create_metric_logger( Args: proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `monarch.actor.this_proc()`. + process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging. + If None, will be auto-detected from the mesh_name provided during actor initialization or + a generic mesh name if one was not provided. Returns: GlobalLoggingActor: The global logging controller. @@ -53,7 +64,7 @@ async def get_or_create_metric_logger( from forge.observability.metrics import record_metric # Main process setup - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") # Initialize logging backends await mlogger.init_backends({ @@ -66,13 +77,14 @@ async def get_or_create_metric_logger( # Training loop for step in range(max_steps): - record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) + record_metric("loss", 1.2, reduction_type=Reduce.MEAN) # ... training code with record_metric() calls ... await mlogger.flush(step) # Log metrics for this step # Shutdown await mlogger.shutdown() """ + # Get or create the singleton global logger global _global_logger if _global_logger is None: @@ -84,6 +96,11 @@ async def get_or_create_metric_logger( # Determine process context proc = proc_mesh if proc_mesh is not None else this_proc() + # Auto-detect process_name from proc mesh if not provided + if process_name is None: + ctx = context() + process_name = ctx.actor_instance.actor_id.actor_name + # Check current state for consistency proc_has_local_fetcher = hasattr(proc, "_local_fetcher") global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) @@ -101,7 +118,7 @@ async def get_or_create_metric_logger( # Setup local_fetcher_actor if needed (unless disabled by environment flag) if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value(): local_fetcher_actor = proc.spawn( - "local_fetcher_actor", LocalFetcherActor, global_logger + "local_fetcher_actor", LocalFetcherActor, global_logger, process_name ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) proc._local_fetcher = local_fetcher_actor # pyre-ignore @@ -117,8 +134,13 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ - def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: + def __init__( + self, + global_logger: Union["GlobalLoggingActor", None] = None, + process_name: str | None = None, + ) -> None: self.global_logger = global_logger + self.process_name = process_name # Passed to MetricCollector for logging _is_initialized = False @endpoint @@ -145,10 +167,22 @@ async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]], config: dict[str, Any], + global_step: int = 0, ) -> None: - """Init local (per-rank) logger backends and MetricCollector.""" + """Init local (per-rank) logger backends and MetricCollector. + + Args: + metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. + config (dict[str, Any]): Backend configurations with logging modes and settings. + global_step (int): Initial step for metrics. + """ collector = MetricCollector() - await collector.init_backends(metadata_per_primary_backend, config) + await collector.init_backends( + metadata_per_primary_backend, + config, + global_step, + process_name=self.process_name, + ) @endpoint async def shutdown(self) -> None: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3ce849ad2..af0c154e2 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,8 +13,9 @@ from typing import Any import pytz -from monarch.actor import context, current_rank +from monarch.actor import current_rank +from forge.observability.utils import get_proc_name_with_rank from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -55,6 +56,12 @@ class Metric: """Container for metric data including key, value, reduction type, and timestamp. Timestamp is automatically set to current EST time if not provided. + + Args: + key: str + value: Any + reduction: Reduce + timestamp: Optional[float] = None """ key: str @@ -68,55 +75,6 @@ def __post_init__(self): self.timestamp = datetime.now(pytz.UTC).timestamp() -def get_actor_name_with_rank() -> str: - """ - Extracts actor information from Monarch context to form a logging name. - - Returns: - str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). - Falls back to "UnknownActor" if context unavailable. - """ - # Add more defensive checks - ctx = context() - if ctx is None or ctx.actor_instance is None: - logger.warning("Context unavailable, using fallback actor name for logging.") - return "UnknownActor" - - actor_instance = ctx.actor_instance - rank = current_rank() - - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - rank_name = "UnknownActor" # fallback - if len(parts) >= 2: - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Extract world ID and proc rank - world_id = world_part.split("[")[0] if "[" in world_part else world_part - - # Extract clean actor name (remove "Configured" suffix if present) - if "[" in actor_part: - actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" - if actor_name.endswith("Configured"): - actor_name = actor_name[:-10] # Remove "Configured" - else: - actor_name = actor_part - - # Use last 4 characters of world_id as replica identifier - # This is deterministic, readable, and works for any number of replicas - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - # Use current_rank().rank as the local rank within the replica - local_rank = rank.rank - - rank_name = f"{actor_name}_{replica_id}_r{local_rank}" - - return rank_name - - def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: """Thin wrapper to send metrics to per-rank local MetricCollectors. @@ -150,11 +108,11 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri states is more precise than merging locally reduced metrics. Args: - states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics, + states (list[dict[str, dict[str, Any]]]): list of state of one or more metrics, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. Returns: - list[Metric]: List of reduced metrics + list[Metric]: list of reduced metrics Example: states = [ @@ -438,11 +396,14 @@ def __init__(self) -> None: self.rank = current_rank().rank self.logger_backends: list[LoggerBackend] = [] self._is_initialized = False + self.process_name: str | None = None async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], + global_step: int = 0, + process_name: str | None = None, ) -> None: """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated @@ -452,10 +413,16 @@ async def init_backends( metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + global_step (int, default 0): Initial step for metrics. + process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: - logger.debug(f"Rank {self.rank}: MetricCollector already initialized") + logger.debug( + f"{get_proc_name_with_rank(self.process_name)}: MetricCollector already initialized" + ) return + self.process_name = process_name + self.global_step = global_step # instantiate local backends if any for backend_name, backend_config in config.items(): @@ -470,7 +437,9 @@ async def init_backends( # instantiate local backend logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, + primary_logger_metadata=primary_metadata, + process_name=process_name, ) self.logger_backends.append(logger_backend) @@ -498,7 +467,7 @@ def push(self, metric: Metric) -> None: "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." " To disable this warning, please call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "or set env variable `FORGE_DISABLE_METRICS=True`" ), @@ -527,7 +496,7 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - dict[str, dict[str, dict[str, Any]]]: Dict of {metric_key: metric_state}, + dict[str, dict[str, dict[str, Any]]]: dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: @@ -544,7 +513,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" + f"Collector for {get_proc_name_with_rank(self.process_name)}: No metrics to flush for global_step {global_step}" ) return {} @@ -569,7 +538,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" + f"Collector for rank {get_proc_name_with_rank(self.process_name)} not initialized. Skipping shutdown" ) return @@ -593,6 +562,7 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -602,6 +572,7 @@ async def init( Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. + process_name (str | None): Process name for logging. Raises: ValueError if missing metadata for shared local init. """ @@ -613,7 +584,7 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: Log a batch of metrics to the backend. Args: - metrics: List of Metric objects to log. + metrics: list of Metric objects to log. global_step: Step number for x-axis alignment across metrics. """ pass @@ -636,12 +607,9 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: - self.prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("reduce_across_ranks", True) - else "Controller" - ) + self.prefix = get_proc_name_with_rank(proc_name=process_name) async def log(self, metrics: list[Metric], global_step: int) -> None: metrics_str = "\n".join( @@ -689,16 +657,13 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = ( - get_actor_name_with_rank() - if role == BackendRole.LOCAL - else "global_controller" - ) + self.name = get_proc_name_with_rank(proc_name=process_name) # Default global mode: only inits on controller if self.reduce_across_ranks: diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py new file mode 100644 index 000000000..4a45274e3 --- /dev/null +++ b/src/forge/observability/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from monarch.actor import context + +logger = logging.getLogger(__name__) + + +def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str: + """ + Returns a unique process identifier from Monarch actor context. + + Format: "ActorName_wxyz_r{rank}" where: + - ActorName: The actor class name (e.g., "TrainActor") + - wxyz: Last 4 chars of world_name (unique replica hash) + - rank: Local rank within the replica (0, 1, 2, ...) + + Note: If called from a direct proccess, defaults to "client_DPROC_r0". + + Args: + proc_name: Optional override for actor name. If None, uses actor_id.actor_name. + + Returns: + str: Unique identifier or fallback name if no context available. + """ + ctx = context() + actor_id = ctx.actor_instance.actor_id + + # Use actor_name from actor_id if not provided + if proc_name is None: + proc_name = actor_id.actor_name + + # Try to get world_name. Each replica has a unique value. + try: + world_name = actor_id.world_name + replica_id = world_name[-4:] if len(world_name) >= 4 else world_name + except BaseException: # Catches pyo3_runtime.PanicException from Rust + # Direct proc (e.g., client) - no world_name available + replica_id = "DPROC" + + # Get rank within the replica. NOT a global rank. + try: + rank = actor_id.rank + except BaseException: # Catches pyo3_runtime.PanicException from Rust + # Direct proc - no rank available + rank = 0 + + return f"{proc_name}_{replica_id}_r{rank}" diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index d999fb700..29164b38f 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -95,12 +95,16 @@ async def main(): } service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(config) # Spawn services first (triggers registrations via provisioner hook) - trainer = await TrainActor.options(**service_config).as_service() - generator = await GeneratorActor.options(**service_config).as_service() + trainer = await TrainActor.options( + **service_config, mesh_name="TrainActor" + ).as_service() + generator = await GeneratorActor.options( + **service_config, mesh_name="GeneratorActor" + ).as_service() for i in range(3): print(f"\n=== Global Step {i} ===") diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 54b093841..d6a63c31a 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -33,7 +33,7 @@ async def run(cfg: DictConfig): ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) if (prompt := cfg.get("prompt")) is None: diff --git a/tests/unit_tests/observability/__init__.py b/tests/unit_tests/observability/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/unit_tests/observability/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py index e8900392c..e35350c11 100644 --- a/tests/unit_tests/observability/conftest.py +++ b/tests/unit_tests/observability/conftest.py @@ -22,13 +22,14 @@ def __init__(self, logger_backend_config=None): self.finish_called = False self.metadata = {} - async def init(self, role="local", primary_logger_metadata=None): + async def init(self, role="local", primary_logger_metadata=None, process_name=None): self.init_called = True self.role = role self.primary_logger_metadata = primary_logger_metadata or {} + self.process_name = process_name - async def log(self, metrics, step): - self.logged_metrics.append((metrics, step)) + async def log(self, metrics, global_step): + self.logged_metrics.append((metrics, global_step)) async def finish(self): self.finish_called = True diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py new file mode 100644 index 000000000..501e13afe --- /dev/null +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Optimized unit tests for metric actors functionality.""" + +import pytest + +from forge.observability.metric_actors import ( + get_or_create_metric_logger, + GlobalLoggingActor, + LocalFetcherActor, +) +from monarch.actor import this_host + + +@pytest.fixture +def global_logger(): + """Create a GlobalLoggingActor for testing.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestGlobalLogger", GlobalLoggingActor) + + +@pytest.fixture +def local_fetcher(global_logger): + """Create a LocalFetcherActor linked to global logger.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger) + + +class TestBasicOperations: + """Test basic operations for actors.""" + + @pytest.mark.asyncio + async def test_local_fetcher_flush(self, local_fetcher): + """Test LocalFetcherActor flush operations.""" + result_with_state = await local_fetcher.flush.call_one( + global_step=1, return_state=True + ) + assert result_with_state == {} + + result_without_state = await local_fetcher.flush.call_one( + global_step=1, return_state=False + ) + assert result_without_state == {} + + @pytest.mark.asyncio + async def test_global_logger_basic_ops(self, global_logger): + """Test GlobalLoggingActor basic operations.""" + count = await global_logger.get_fetcher_count.call_one() + assert count >= 0 + + has_fetcher = await global_logger.has_fetcher.call_one("nonexistent") + assert has_fetcher is False + + # Global logger flush (should not raise error) + await global_logger.flush.call_one(global_step=1) + + @pytest.mark.asyncio + async def test_backend_init(self, local_fetcher): + """Test backend initialization and shutdown.""" + metadata = {"wandb": {"shared_run_id": "test123"}} + config = {"console": {"logging_mode": "per_rank_reduce"}} + + await local_fetcher.init_backends.call_one(metadata, config, global_step=5) + await local_fetcher.shutdown.call_one() + + +class TestRegistrationLifecycle: + """Test registration lifecycle.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_registration_lifecycle(self, global_logger, local_fetcher): + """Test complete registration/deregistration lifecycle.""" + proc_name = "lifecycle_test_proc" + + # Initial state + initial_count = await global_logger.get_fetcher_count.call_one() + assert await global_logger.has_fetcher.call_one(proc_name) is False + + # Register + await global_logger.register_fetcher.call_one(local_fetcher, proc_name) + + # Verify registered + new_count = await global_logger.get_fetcher_count.call_one() + assert new_count == initial_count + 1 + assert await global_logger.has_fetcher.call_one(proc_name) is True + + # Deregister + await global_logger.deregister_fetcher.call_one(proc_name) + + # Verify deregistered + final_count = await global_logger.get_fetcher_count.call_one() + assert final_count == initial_count + assert await global_logger.has_fetcher.call_one(proc_name) is False + + +class TestBackendConfiguration: + """Test backend configuration validation.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_valid_backend_configs(self, global_logger): + """Test valid backend configurations.""" + # Empty config + await global_logger.init_backends.call_one({}) + + # Valid configs for all logging modes + for mode in ["per_rank_reduce", "per_rank_no_reduce", "global_reduce"]: + config = {"console": {"logging_mode": mode}} + await global_logger.init_backends.call_one(config) + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_invalid_backend_configs(self, global_logger): + """Test invalid backend configurations are handled gracefully.""" + # Empty config should work + await global_logger.init_backends.call_one({}) + + # Config with only project should work + config_with_project = {"console": {"project": "test_project"}} + await global_logger.init_backends.call_one(config_with_project) + + # Config with reduce_across_ranks should work (Diff 3 doesn't validate logging_mode yet) + config_with_reduce = {"console": {"reduce_across_ranks": True}} + await global_logger.init_backends.call_one(config_with_reduce) + + +class TestErrorHandling: + """Test error handling scenarios.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_deregister_nonexistent_fetcher(self, global_logger): + """Test deregistering non-existent fetcher doesn't crash.""" + await global_logger.deregister_fetcher.call_one("nonexistent_proc") + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_shutdown(self, global_logger): + """Test shutdown without issues.""" + await global_logger.shutdown.call_one() + + +class TestGetOrCreateMetricLogger: + """Test the integration function.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_get_or_create_functionality(self): + """Test get_or_create_metric_logger basic functionality.""" + result = await get_or_create_metric_logger(process_name="TestController") + + # Should return a GlobalLoggingActor mesh + assert result is not None + + # Should be able to call basic methods + count = await result.get_fetcher_count.call_one() + assert count >= 0 diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 701bda2dc..d0f104459 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -80,12 +80,9 @@ def test_new_enums_and_constants(self): assert isinstance(BackendRole.LOCAL, BackendRole) assert isinstance(BackendRole.GLOBAL, BackendRole) - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_backend_role_usage(self, mock_actor_name): + async def test_backend_role_usage(self): """Test that BackendRole constants are actually used instead of string literals.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - # Test ConsoleBackend console_backend = ConsoleBackend({}) await console_backend.init(role=BackendRole.LOCAL) @@ -295,10 +292,8 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): mock_collector_class.assert_called_once() mock_collector.push.assert_called_once() - @patch("forge.observability.metrics.get_actor_name_with_rank") - def test_wandb_backend_creation(self, mock_actor_name): + def test_wandb_backend_creation(self): """Test WandbBackend creation and basic setup without WandB dependency.""" - mock_actor_name.return_value = "TestActor_abcd_r0" config = { "project": "test_project", @@ -316,12 +311,9 @@ def test_wandb_backend_creation(self, mock_actor_name): metadata = backend.get_metadata_for_secondary_ranks() assert metadata == {} # Should be empty when no run - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_console_backend(self, mock_actor_name): + async def test_console_backend(self): """Test ConsoleBackend basic operations.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - backend = ConsoleBackend({}) await backend.init(role=BackendRole.LOCAL) @@ -425,8 +417,10 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetche if hasattr(procs, "_local_fetcher"): delattr(procs, "_local_fetcher") - # Test functionality - global_logger = await get_or_create_metric_logger(proc_mesh=procs) + # Test functionality - pass explicit process_name since test bypasses provisioner + global_logger = await get_or_create_metric_logger( + proc_mesh=procs, process_name="TestProcess" + ) # Get results to check proc_has_fetcher = hasattr(procs, "_local_fetcher") diff --git a/tests/unit_tests/observability/test_utils.py b/tests/unit_tests/observability/test_utils.py new file mode 100644 index 000000000..9a4e24d0c --- /dev/null +++ b/tests/unit_tests/observability/test_utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for observability utility functions.""" + +import pytest + +from forge.observability.utils import get_proc_name_with_rank +from monarch.actor import Actor, endpoint, this_host + + +class UtilActor(Actor): + """Actor for testing get_proc_name_with_rank in spawned context.""" + + @endpoint + async def get_name(self) -> str: + return get_proc_name_with_rank() + + @endpoint + async def get_name_with_override(self, name: str) -> str: + return get_proc_name_with_rank(proc_name=name) + + +class TestGetProcNameWithRank: + """Tests for get_proc_name_with_rank utility.""" + + def test_direct_proc(self): + """Direct proc (test process) should return client_DPROC_r0.""" + result = get_proc_name_with_rank() + assert result == "client_DPROC_r0" + + def test_direct_proc_with_override(self): + """Direct proc with override should use provided name.""" + result = get_proc_name_with_rank(proc_name="MyProcess") + assert result == "MyProcess_DPROC_r0" + + @pytest.mark.timeout(10) + @pytest.mark.asyncio + async def test_spawned_actor(self): + """Spawned actor should return ActorName_replica_rank format.""" + p = this_host().spawn_procs(per_host={"cpus": 2}) + actor = p.spawn("UtilActor", UtilActor) + + # no override + results = await actor.get_name.call() + + assert len(results) == 2 + for i, (rank_info, result) in enumerate(results): + replica_id = result.split("_")[1] + assert result == f"UtilActor_{replica_id}_r{i}" + + # override name + results = await actor.get_name_with_override.call("CustomName") + + for i, (rank_info, result) in enumerate(results): + replica_id = result.split("_")[1] + assert result == f"CustomName_{replica_id}_r{i}"