|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Optimized unit tests for metric actors functionality.""" |
| 8 | + |
| 9 | +import pytest |
| 10 | + |
| 11 | +from forge.observability.metric_actors import ( |
| 12 | + get_or_create_metric_logger, |
| 13 | + GlobalLoggingActor, |
| 14 | + LocalFetcherActor, |
| 15 | +) |
| 16 | +from monarch.actor import this_host |
| 17 | + |
| 18 | + |
@pytest.fixture
def global_logger():
    """Provide a GlobalLoggingActor spawned under the name "TestGlobalLogger".

    The actor is spawned on processes created on this host with
    ``per_host={"cpus": 1}``. No teardown is performed by this fixture.
    """
    procs = this_host().spawn_procs(per_host={"cpus": 1})
    logger_actor = procs.spawn("TestGlobalLogger", GlobalLoggingActor)
    return logger_actor
| 24 | + |
| 25 | + |
@pytest.fixture
def local_fetcher(global_logger):
    """Provide a LocalFetcherActor spawned under the name "TestLocalFetcher".

    The ``global_logger`` fixture value is forwarded as the actor's only
    spawn argument. The actor is spawned on its own set of processes created
    on this host with ``per_host={"cpus": 1}``.
    """
    procs = this_host().spawn_procs(per_host={"cpus": 1})
    fetcher_actor = procs.spawn("TestLocalFetcher", LocalFetcherActor, global_logger)
    return fetcher_actor
| 31 | + |
| 32 | + |
class TestBasicOperations:
    """Smoke tests for the basic endpoints of the fetcher and logger actors."""

    @pytest.mark.asyncio
    async def test_local_fetcher_flush(self, local_fetcher):
        """Check that flush returns an empty dict for both return_state values."""
        # return_state=True is exercised first, then return_state=False.
        for want_state in (True, False):
            flushed = await local_fetcher.flush.call_one(
                global_step=1, return_state=want_state
            )
            assert flushed == {}

    @pytest.mark.asyncio
    async def test_global_logger_basic_ops(self, global_logger):
        """Exercise get_fetcher_count, has_fetcher and flush on the logger."""
        fetcher_count = await global_logger.get_fetcher_count.call_one()
        assert fetcher_count >= 0

        is_known = await global_logger.has_fetcher.call_one("nonexistent")
        assert is_known is False

        # The flush call only has to complete without raising.
        await global_logger.flush.call_one(global_step=1)

    @pytest.mark.asyncio
    async def test_backend_init(self, local_fetcher):
        """Initialize backends on the fetcher, then shut it down."""
        backend_config = {"console": {"logging_mode": "per_rank_reduce"}}
        backend_metadata = {"wandb": {"shared_run_id": "test123"}}

        await local_fetcher.init_backends.call_one(
            backend_metadata, backend_config, global_step=5
        )
        await local_fetcher.shutdown.call_one()
| 69 | + |
| 70 | + |
class TestRegistrationLifecycle:
    """Covers registering and deregistering a fetcher on the global logger."""

    @pytest.mark.timeout(3)
    @pytest.mark.asyncio
    async def test_registration_lifecycle(self, global_logger, local_fetcher):
        """Register a fetcher, confirm it is tracked, deregister, confirm removal."""
        proc_name = "lifecycle_test_proc"

        # Baseline: remember the current count; the name must be unknown.
        count_before = await global_logger.get_fetcher_count.call_one()
        known_before = await global_logger.has_fetcher.call_one(proc_name)
        assert known_before is False

        # Registration adds exactly one fetcher and makes the name visible.
        await global_logger.register_fetcher.call_one(local_fetcher, proc_name)
        count_registered = await global_logger.get_fetcher_count.call_one()
        assert count_registered == count_before + 1
        known_registered = await global_logger.has_fetcher.call_one(proc_name)
        assert known_registered is True

        # Deregistration restores the baseline count and hides the name again.
        await global_logger.deregister_fetcher.call_one(proc_name)
        count_after = await global_logger.get_fetcher_count.call_one()
        assert count_after == count_before
        known_after = await global_logger.has_fetcher.call_one(proc_name)
        assert known_after is False
| 99 | + |
| 100 | + |
class TestBackendConfiguration:
    """Checks which backend configurations init_backends accepts."""

    @pytest.mark.timeout(3)
    @pytest.mark.asyncio
    async def test_valid_backend_configs(self, global_logger):
        """Pass an empty config and one console config per logging mode."""
        logging_modes = ("per_rank_reduce", "per_rank_no_reduce", "global_reduce")

        # The empty config goes first, followed by one config per mode.
        accepted_configs = [{}]
        accepted_configs.extend(
            {"console": {"logging_mode": logging_mode}}
            for logging_mode in logging_modes
        )

        for backend_config in accepted_configs:
            await global_logger.init_backends.call_one(backend_config)

    @pytest.mark.timeout(3)
    @pytest.mark.asyncio
    async def test_invalid_backend_configs(self, global_logger):
        """Pass configs without a logging_mode; each call must not raise.

        Despite the test name, every config here is expected to be accepted.
        """
        tolerated_configs = (
            # Empty config should work.
            {},
            # Config carrying only a project should work.
            {"console": {"project": "test_project"}},
            # Config carrying reduce_across_ranks should work; the original
            # note says logging_mode is not validated at this stage.
            {"console": {"reduce_across_ranks": True}},
        )

        for backend_config in tolerated_configs:
            await global_logger.init_backends.call_one(backend_config)
| 130 | + |
| 131 | + |
class TestErrorHandling:
    """Checks that edge-case calls on the global logger do not raise."""

    @pytest.mark.timeout(3)
    @pytest.mark.asyncio
    async def test_deregister_nonexistent_fetcher(self, global_logger):
        """Deregister a name that was never registered; the call must not raise."""
        unregistered_name = "nonexistent_proc"
        await global_logger.deregister_fetcher.call_one(unregistered_name)

    @pytest.mark.timeout(3)
    @pytest.mark.asyncio
    async def test_shutdown(self, global_logger):
        """Shut the global logger down; the call must not raise."""
        await global_logger.shutdown.call_one()
| 146 | + |
| 147 | + |
class TestGetOrCreateMetricLogger:
    """Integration check for get_or_create_metric_logger."""

    @pytest.mark.timeout(3)
    @pytest.mark.asyncio
    async def test_get_or_create_functionality(self):
        """Obtain a logger via get_or_create_metric_logger and query it once."""
        metric_logger = await get_or_create_metric_logger(
            process_name="TestController"
        )

        # A handle must come back (per the original note, it is expected to
        # be a GlobalLoggingActor mesh).
        assert metric_logger is not None

        # The returned handle must answer a basic endpoint call.
        fetcher_count = await metric_logger.get_fetcher_count.call_one()
        assert fetcher_count >= 0
0 commit comments