
Commit 067a8b0

test
1 parent 4c2e953 commit 067a8b0

File tree

7 files changed: +36 -38 lines changed

.flake8

Lines changed: 2 additions & 1 deletion
@@ -8,8 +8,9 @@ max-line-length = 120
 # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
 # E731 allow usage of assigning lambda expressions
 # N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
+# E704 ignored to allow black's formatting of Protocol stub methods (def method(self) -> None: ...)
 ignore =
-    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
+    E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
 # shebang has extra meaning in fbcode lints, so I think it's not worth trying
 # to line this up with executable bit
     EXE001,
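
For context: E704 is flake8's "statement on same line as def" check, which fires on the one-line ellipsis stubs that black produces for Protocol methods. A minimal sketch of the pattern being exempted (the class and method names are illustrative, not from this repo):

from typing import Protocol


class LoggerBackend(Protocol):
    # black keeps ellipsis-bodied stubs on one line; without the E704
    # exemption above, flake8 would flag each of these defs.
    def init(self, role: str) -> None: ...
    def log(self, name: str, value: float) -> None: ...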

pyproject.toml

Lines changed: 1 addition & 15 deletions
@@ -100,23 +100,9 @@ prerelease = "allow"
 environments = [
     "sys_platform == 'linux'",
 ]
-# override-dependencies = ["torch>2.7.1", "torchaudio>=2.7.1", "torchvision>=0.22.0"]
-
-[tool.ufmt]
-formatter = "ruff-api"
-sorter = "usort"
 
 [tool.black]
-target-version = ["py310"]
+target-version = ["py310"] # match the minimum supported python version
 
 [tool.usort]
 first_party_detection = false
-
-[tool.ruff]
-target-version = "py310"
-
-[tool.ruff.lint.isort]
-case-sensitive = false
-combine-as-imports = true
-detect-same-package = false
-order-by-type = false
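
The surviving [tool.black] table drives the reformatting in the rest of this commit: black only emits the parenthesized multi-context-manager style seen in the test files below when every configured target version supports that grammar. A hedged sketch of observing this through black's format_str/Mode API (the input string is made up for illustration):

import black

# One logical with-statement, too long for black's default 88 columns.
src = (
    'with patch("forge.observability.metrics.context") as mock_context, '
    'patch("forge.observability.metrics.current_rank") as mock_rank:\n'
    "    pass\n"
)

# With a py310 target, black may split the line using parenthesized
# context managers, the style this commit adopts.
mode = black.Mode(target_versions={black.TargetVersion.PY310})
print(black.format_str(src, mode=mode))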

src/forge/observability/perf_tracker.py

Lines changed: 3 additions & 3 deletions
@@ -276,9 +276,9 @@ def __init__(self, max_workers: int = 2) -> None:
         if not torch.cuda.is_available():
             raise RuntimeError("CUDA is not available for timing")
         self._executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._futures: list[
-            tuple[str, Future[float], int]
-        ] = []  # (name, future, submission_index)
+        self._futures: list[tuple[str, Future[float], int]] = (
+            []
+        )  # (name, future, submission_index)
         self._durations: list[tuple[str, float]] = []
         self._chain_start: torch.cuda.Event | None = None
 
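This hunk is the same formatter churn: for an annotated assignment with a long annotation and a short value, newer black keeps the annotation on one line and parenthesizes the value, which also keeps the trailing comment attached to the assignment. Both spellings are equivalent at runtime; a self-contained check:

from concurrent.futures import Future

# New style: annotation intact, value hugged by parentheses.
_futures: list[tuple[str, Future[float], int]] = (
    []
)  # (name, future, submission_index)
print(_futures)  # []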

tests/unit_tests/datasets/test_hf.py

Lines changed: 6 additions & 6 deletions
@@ -229,12 +229,12 @@ def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
         ), f"Shuffled epochs should be shuffled differently, got {first_epoch_ids} and {second_epoch_ids}"
 
         # But should contain the same set of IDs
-        assert (
-            set(first_epoch_ids) == set(range(SMALL_DATASET_SIZE))
-        ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {first_epoch_ids}"
-        assert (
-            set(second_epoch_ids) == set(range(SMALL_DATASET_SIZE))
-        ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE-1}), got {second_epoch_ids}"
+        assert set(first_epoch_ids) == set(
+            range(SMALL_DATASET_SIZE)
+        ), f"First epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {first_epoch_ids}"
+        assert set(second_epoch_ids) == set(
+            range(SMALL_DATASET_SIZE)
+        ), f"Second epoch samples should be (0-{SMALL_DATASET_SIZE - 1}), got {second_epoch_ids}"
 
     def test_epoch_tracking(self, dataset_factory, small_dataset_file):
         """Test that epoch number is correctly tracked across dataset restarts."""

tests/unit_tests/observability/conftest.py

Lines changed: 4 additions & 3 deletions
@@ -47,9 +47,10 @@ def mock_rank():
 @pytest.fixture
 def mock_actor_context():
     """Mock Monarch actor context for testing actor name generation."""
-    with patch("forge.observability.metrics.context") as mock_context, patch(
-        "forge.observability.metrics.current_rank"
-    ) as mock_rank:
+    with (
+        patch("forge.observability.metrics.context") as mock_context,
+        patch("forge.observability.metrics.current_rank") as mock_rank,
+    ):
         # Setup mock context
         ctx = MagicMock()
         actor_instance = MagicMock()
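
The parenthesized form needs the Python 3.10-era grammar (consistent with the py310 floor in pyproject.toml) but is behaviorally identical: the managers enter and exit in the same order, and a trailing comma is now legal. A runnable sketch with stand-in patch targets:

from unittest.mock import patch

# Equivalent to `with patch(...) as a, patch(...) as b:`, just written
# one context manager per line inside parentheses.
with (
    patch("os.getcwd", return_value="/tmp") as mock_cwd,
    patch("os.cpu_count", return_value=1) as mock_cpus,
):
    import os

    assert os.getcwd() == "/tmp"
    assert os.cpu_count() == 1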

tests/unit_tests/observability/test_metrics.py

Lines changed: 17 additions & 8 deletions
@@ -94,10 +94,11 @@ async def test_backend_role_usage(self):
         )
 
         # Mock all the WandB init methods to focus only on role validation
-        with patch.object(wandb_backend, "_init_global"), patch.object(
-            wandb_backend, "_init_shared_global"
-        ), patch.object(wandb_backend, "_init_shared_local"), patch.object(
-            wandb_backend, "_init_per_rank"
+        with (
+            patch.object(wandb_backend, "_init_global"),
+            patch.object(wandb_backend, "_init_shared_global"),
+            patch.object(wandb_backend, "_init_shared_local"),
+            patch.object(wandb_backend, "_init_per_rank"),
         ):
             # Should not raise error for valid roles (type system prevents invalid values)
             await wandb_backend.init(role=BackendRole.GLOBAL)
@@ -433,11 +434,19 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetchers
 
         # Assert based on expected behavior
         if should_register_fetchers:
-            assert proc_has_fetcher, f"Expected process to have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
-            assert global_has_fetcher, f"Expected global logger to have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
+            assert (
+                proc_has_fetcher
+            ), f"Expected process to have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
+            assert (
+                global_has_fetcher
+            ), f"Expected global logger to have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
         else:
-            assert not proc_has_fetcher, f"Expected process to NOT have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
-            assert not global_has_fetcher, f"Expected global logger to NOT have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
+            assert (
+                not proc_has_fetcher
+            ), f"Expected process to NOT have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"
+            assert (
+                not global_has_fetcher
+            ), f"Expected global logger to NOT have fetcher registered when FORGE_DISABLE_METRICS={env_var_value}"
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize(
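
Same mechanics as in test_hf.py: when an "assert condition, message" statement overflows the line, black parenthesizes the condition so the long f-string message gets its own line. One pitfall this style correctly avoids: the parentheses must wrap the condition only, since assert (condition, message) would build a two-element tuple that is always truthy. A tiny sketch:

proc_has_fetcher = True
env_var_value = "0"

# Parentheses around the condition only; the message stays outside.
assert (
    proc_has_fetcher
), f"Expected process to have _local_fetcher when FORGE_DISABLE_METRICS={env_var_value}"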

tests/unit_tests/observability/test_perf_tracker.py

Lines changed: 3 additions & 2 deletions
@@ -389,8 +389,9 @@ def test_metric_timer_uses_gpu_override(
         if env_value == "true" and not torch.cuda.is_available():
             pytest.skip("CUDA not available")
 
-        with patch("torch.cuda.is_available", return_value=True), patch(
-            "forge.observability.perf_tracker.record_metric"
+        with (
+            patch("torch.cuda.is_available", return_value=True),
+            patch("forge.observability.perf_tracker.record_metric"),
         ):
             monkeypatch.setenv(METRIC_TIMER_USES_GPU.name, env_value)
 
