Skip to content

Commit c6be26f

Browse files
committed
added test docstrings
1 parent e77895d commit c6be26f

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tests/test_smoke_training.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,40 @@
77

88
class DummyAccelerator:
99
def __init__(self):
10+
"""Initialize a CPU-only dummy accelerator."""
1011
self.device = torch.device("cpu")
1112

1213
def autocast(self):
14+
"""Provide a no-op context manager for mixed precision."""
1315
from contextlib import nullcontext
1416

1517
return nullcontext()
1618

1719
def backward(self, loss):
20+
"""Perform a standard backward pass on the given loss."""
1821
loss.backward()
1922

2023
def clip_grad_norm_(self, params, max_norm):
24+
"""Clip gradients of params to the given max_norm."""
2125
torch.nn.utils.clip_grad_norm_(params, max_norm)
2226

2327

2428
class DummyFID:
2529
def reset(self):
30+
"""Reset dummy FID state (no-op)."""
2631
pass
2732

2833
def update(self, *args, **kwargs):
34+
"""Update dummy FID with new data (no-op)."""
2935
pass
3036

3137
def compute(self):
38+
"""Return a zero tensor as the dummy FID score."""
3239
return torch.tensor(0.0)
3340

3441

3542
def test_smoke_training(tmp_path, monkeypatch):
43+
"""Run a tiny end-to-end training epoch to verify nothing breaks."""
3644
root = create_utk_dataset(tmp_path)
3745
transform = T.Compose([T.ToTensor()])
3846
train_loader = data.make_unpaired_loader(

0 commit comments

Comments
 (0)