Skip to content

Commit 92326bc

Browse files
author
Felipe Mello
committed
Reapply "Metric Logging updates 4/N - better actor name (#351)" (#429)
This reverts commit 633b219.
1 parent c0bbc80 commit 92326bc

File tree

7 files changed

+185
-24
lines changed

7 files changed

+185
-24
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
305305
provisioner = await init_provisioner()
306306

307307
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
308-
mlogger = await get_or_create_metric_logger()
308+
mlogger = await get_or_create_metric_logger(process_name="Controller")
309309
await mlogger.init_backends.call_one(metric_logging_cfg)
310310

311311
# ---- Setup services ---- #

src/forge/observability/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from .metrics import (
1313
BackendRole,
1414
ConsoleBackend,
15-
get_actor_name_with_rank,
16-
get_logger_backend_class,
1715
LoggerBackend,
1816
MaxAccumulator,
1917
MeanAccumulator,
@@ -29,12 +27,12 @@
2927
WandbBackend,
3028
)
3129
from .perf_tracker import trace, Tracer
30+
from .utils import get_proc_name_with_rank
3231

3332
__all__ = [
3433
# Main API functions
3534
"record_metric",
3635
"reduce_metrics_states",
37-
"get_actor_name_with_rank",
3836
"get_logger_backend_class",
3937
"get_or_create_metric_logger",
4038
# Performance tracking
@@ -45,6 +43,8 @@
4543
"BackendRole",
4644
# Enums
4745
"Reduce",
46+
# Utility functions
47+
"get_proc_name_with_rank",
4848
# Actor classes
4949
"GlobalLoggingActor",
5050
"LocalFetcherActor",

tests/sandbox/toy_rl/toy_metrics/main.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,16 @@ async def main():
9595
}
9696

9797
service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
98-
mlogger = await get_or_create_metric_logger()
98+
mlogger = await get_or_create_metric_logger(process_name="Controller")
9999
await mlogger.init_backends.call_one(config)
100100

101101
# Spawn services first (triggers registrations via provisioner hook)
102-
trainer = await TrainActor.options(**service_config).as_service()
103-
generator = await GeneratorActor.options(**service_config).as_service()
102+
trainer = await TrainActor.options(
103+
**service_config, mesh_name="TrainActor"
104+
).as_service()
105+
generator = await GeneratorActor.options(
106+
**service_config, mesh_name="GeneratorActor"
107+
).as_service()
104108

105109
for i in range(3):
106110
print(f"\n=== Global Step {i} ===")

tests/sandbox/vllm/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def run(cfg: DictConfig):
3333
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
3434
)
3535
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
36-
mlogger = await get_or_create_metric_logger()
36+
mlogger = await get_or_create_metric_logger(process_name="Controller")
3737
await mlogger.init_backends.call_one(metric_logging_cfg)
3838

3939
if (prompt := cfg.get("prompt")) is None:

tests/unit_tests/observability/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ def __init__(self, logger_backend_config=None):
2222
self.finish_called = False
2323
self.metadata = {}
2424

25-
async def init(self, role="local", primary_logger_metadata=None):
25+
async def init(self, role="local", primary_logger_metadata=None, process_name=None):
2626
self.init_called = True
2727
self.role = role
2828
self.primary_logger_metadata = primary_logger_metadata or {}
29+
self.process_name = process_name
2930

30-
async def log(self, metrics, step):
31-
self.logged_metrics.append((metrics, step))
31+
async def log(self, metrics, global_step):
32+
self.logged_metrics.append((metrics, global_step))
3233

3334
async def finish(self):
3435
self.finish_called = True
Lines changed: 162 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,162 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Optimized unit tests for metric actors functionality."""
8+
9+
import pytest
10+
11+
from forge.observability.metric_actors import (
12+
get_or_create_metric_logger,
13+
GlobalLoggingActor,
14+
LocalFetcherActor,
15+
)
16+
from monarch.actor import this_host
17+
18+
19+
@pytest.fixture
20+
def global_logger():
21+
"""Create a GlobalLoggingActor for testing."""
22+
p = this_host().spawn_procs(per_host={"cpus": 1})
23+
return p.spawn("TestGlobalLogger", GlobalLoggingActor)
24+
25+
26+
@pytest.fixture
27+
def local_fetcher(global_logger):
28+
"""Create a LocalFetcherActor linked to global logger."""
29+
p = this_host().spawn_procs(per_host={"cpus": 1})
30+
return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger)
31+
32+
33+
class TestBasicOperations:
34+
"""Test basic operations for actors."""
35+
36+
@pytest.mark.asyncio
37+
async def test_local_fetcher_flush(self, local_fetcher):
38+
"""Test LocalFetcherActor flush operations."""
39+
result_with_state = await local_fetcher.flush.call_one(
40+
global_step=1, return_state=True
41+
)
42+
assert result_with_state == {}
43+
44+
result_without_state = await local_fetcher.flush.call_one(
45+
global_step=1, return_state=False
46+
)
47+
assert result_without_state == {}
48+
49+
@pytest.mark.asyncio
50+
async def test_global_logger_basic_ops(self, global_logger):
51+
"""Test GlobalLoggingActor basic operations."""
52+
count = await global_logger.get_fetcher_count.call_one()
53+
assert count >= 0
54+
55+
has_fetcher = await global_logger.has_fetcher.call_one("nonexistent")
56+
assert has_fetcher is False
57+
58+
# Global logger flush (should not raise error)
59+
await global_logger.flush.call_one(global_step=1)
60+
61+
@pytest.mark.asyncio
62+
async def test_backend_init(self, local_fetcher):
63+
"""Test backend initialization and shutdown."""
64+
metadata = {"wandb": {"shared_run_id": "test123"}}
65+
config = {"console": {"logging_mode": "per_rank_reduce"}}
66+
67+
await local_fetcher.init_backends.call_one(metadata, config, global_step=5)
68+
await local_fetcher.shutdown.call_one()
69+
70+
71+
class TestRegistrationLifecycle:
72+
"""Test registration lifecycle."""
73+
74+
@pytest.mark.timeout(3)
75+
@pytest.mark.asyncio
76+
async def test_registration_lifecycle(self, global_logger, local_fetcher):
77+
"""Test complete registration/deregistration lifecycle."""
78+
proc_name = "lifecycle_test_proc"
79+
80+
# Initial state
81+
initial_count = await global_logger.get_fetcher_count.call_one()
82+
assert await global_logger.has_fetcher.call_one(proc_name) is False
83+
84+
# Register
85+
await global_logger.register_fetcher.call_one(local_fetcher, proc_name)
86+
87+
# Verify registered
88+
new_count = await global_logger.get_fetcher_count.call_one()
89+
assert new_count == initial_count + 1
90+
assert await global_logger.has_fetcher.call_one(proc_name) is True
91+
92+
# Deregister
93+
await global_logger.deregister_fetcher.call_one(proc_name)
94+
95+
# Verify deregistered
96+
final_count = await global_logger.get_fetcher_count.call_one()
97+
assert final_count == initial_count
98+
assert await global_logger.has_fetcher.call_one(proc_name) is False
99+
100+
101+
class TestBackendConfiguration:
102+
"""Test backend configuration validation."""
103+
104+
@pytest.mark.timeout(3)
105+
@pytest.mark.asyncio
106+
async def test_valid_backend_configs(self, global_logger):
107+
"""Test valid backend configurations."""
108+
# Empty config
109+
await global_logger.init_backends.call_one({})
110+
111+
# Valid configs for all logging modes
112+
for mode in ["per_rank_reduce", "per_rank_no_reduce", "global_reduce"]:
113+
config = {"console": {"logging_mode": mode}}
114+
await global_logger.init_backends.call_one(config)
115+
116+
@pytest.mark.timeout(3)
117+
@pytest.mark.asyncio
118+
async def test_invalid_backend_configs(self, global_logger):
119+
"""Test invalid backend configurations are handled gracefully."""
120+
# Empty config should work
121+
await global_logger.init_backends.call_one({})
122+
123+
# Config with only project should work
124+
config_with_project = {"console": {"project": "test_project"}}
125+
await global_logger.init_backends.call_one(config_with_project)
126+
127+
# Config with reduce_across_ranks should work (Diff 3 doesn't validate logging_mode yet)
128+
config_with_reduce = {"console": {"reduce_across_ranks": True}}
129+
await global_logger.init_backends.call_one(config_with_reduce)
130+
131+
132+
class TestErrorHandling:
133+
"""Test error handling scenarios."""
134+
135+
@pytest.mark.timeout(3)
136+
@pytest.mark.asyncio
137+
async def test_deregister_nonexistent_fetcher(self, global_logger):
138+
"""Test deregistering non-existent fetcher doesn't crash."""
139+
await global_logger.deregister_fetcher.call_one("nonexistent_proc")
140+
141+
@pytest.mark.timeout(3)
142+
@pytest.mark.asyncio
143+
async def test_shutdown(self, global_logger):
144+
"""Test shutdown without issues."""
145+
await global_logger.shutdown.call_one()
146+
147+
148+
class TestGetOrCreateMetricLogger:
149+
"""Test the integration function."""
150+
151+
@pytest.mark.timeout(3)
152+
@pytest.mark.asyncio
153+
async def test_get_or_create_functionality(self):
154+
"""Test get_or_create_metric_logger basic functionality."""
155+
result = await get_or_create_metric_logger(process_name="TestController")
156+
157+
# Should return a GlobalLoggingActor mesh
158+
assert result is not None
159+
160+
# Should be able to call basic methods
161+
count = await result.get_fetcher_count.call_one()
162+
assert count >= 0

tests/unit_tests/observability/test_metrics.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,9 @@ def test_new_enums_and_constants(self):
8080
assert isinstance(BackendRole.LOCAL, BackendRole)
8181
assert isinstance(BackendRole.GLOBAL, BackendRole)
8282

83-
@patch("forge.observability.metrics.get_actor_name_with_rank")
8483
@pytest.mark.asyncio
85-
async def test_backend_role_usage(self, mock_actor_name):
84+
async def test_backend_role_usage(self):
8685
"""Test that BackendRole constants are actually used instead of string literals."""
87-
mock_actor_name.return_value = "TestActor_abcd_r0"
88-
8986
# Test ConsoleBackend
9087
console_backend = ConsoleBackend({})
9188
await console_backend.init(role=BackendRole.LOCAL)
@@ -295,10 +292,8 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
295292
mock_collector_class.assert_called_once()
296293
mock_collector.push.assert_called_once()
297294

298-
@patch("forge.observability.metrics.get_actor_name_with_rank")
299-
def test_wandb_backend_creation(self, mock_actor_name):
295+
def test_wandb_backend_creation(self):
300296
"""Test WandbBackend creation and basic setup without WandB dependency."""
301-
mock_actor_name.return_value = "TestActor_abcd_r0"
302297

303298
config = {
304299
"project": "test_project",
@@ -316,12 +311,9 @@ def test_wandb_backend_creation(self, mock_actor_name):
316311
metadata = backend.get_metadata_for_secondary_ranks()
317312
assert metadata == {} # Should be empty when no run
318313

319-
@patch("forge.observability.metrics.get_actor_name_with_rank")
320314
@pytest.mark.asyncio
321-
async def test_console_backend(self, mock_actor_name):
315+
async def test_console_backend(self):
322316
"""Test ConsoleBackend basic operations."""
323-
mock_actor_name.return_value = "TestActor_abcd_r0"
324-
325317
backend = ConsoleBackend({})
326318

327319
await backend.init(role=BackendRole.LOCAL)
@@ -425,8 +417,10 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetche
425417
if hasattr(procs, "_local_fetcher"):
426418
delattr(procs, "_local_fetcher")
427419

428-
# Test functionality
429-
global_logger = await get_or_create_metric_logger(proc_mesh=procs)
420+
# Test functionality - pass explicit process_name since test bypasses provisioner
421+
global_logger = await get_or_create_metric_logger(
422+
proc_mesh=procs, process_name="TestProcess"
423+
)
430424

431425
# Get results to check
432426
proc_has_fetcher = hasattr(procs, "_local_fetcher")

0 commit comments

Comments
 (0)