Skip to content

Commit dbcf3b4

Browse files
committed
refactor: utils
1 parent 719b493 commit dbcf3b4

File tree

2 files changed

+21
-20
lines changed

2 files changed

+21
-20
lines changed

tests/test_optimizers.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@
2626
Shampoo,
2727
)
2828
from tests.utils import (
29-
LogisticRegression,
3029
MultiHeadLogisticRegression,
30+
build_environment,
3131
build_lookahead,
3232
dummy_closure,
3333
ids,
3434
make_dataset,
35+
tensor_to_numpy,
3536
)
3637

3738
OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
@@ -85,25 +86,6 @@
8586
]
8687

8788

88-
def tensor_to_numpy(x: torch.Tensor) -> np.ndarray:
89-
return x.detach().cpu().numpy()
90-
91-
92-
def build_environment(use_gpu: bool = False) -> Tuple[Tuple[torch.Tensor, torch.Tensor], nn.Module, nn.Module]:
93-
torch.manual_seed(42)
94-
95-
x_data, y_data = make_dataset()
96-
model: nn.Module = LogisticRegression()
97-
loss_fn: nn.Module = nn.BCEWithLogitsLoss()
98-
99-
if use_gpu and torch.cuda.is_available():
100-
x_data, y_data = x_data.cuda(), y_data.cuda()
101-
model = model.cuda()
102-
loss_fn = loss_fn.cuda()
103-
104-
return (x_data, y_data), model, loss_fn
105-
106-
10789
@pytest.mark.parametrize('optimizer_fp32_config', OPTIMIZERS, ids=ids)
10890
def test_f32_optimizers(optimizer_fp32_config):
10991
(x_data, y_data), model, loss_fn = build_environment()

tests/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,22 @@ def build_lookahead(*parameters, **kwargs):
7676

7777
def ids(v) -> str:
7878
return f'{v[0].__name__}_{v[1:]}'
79+
80+
81+
def build_environment(use_gpu: bool = False) -> Tuple[Tuple[torch.Tensor, torch.Tensor], nn.Module, nn.Module]:
82+
torch.manual_seed(42)
83+
84+
x_data, y_data = make_dataset()
85+
model: nn.Module = LogisticRegression()
86+
loss_fn: nn.Module = nn.BCEWithLogitsLoss()
87+
88+
if use_gpu and torch.cuda.is_available():
89+
x_data, y_data = x_data.cuda(), y_data.cuda()
90+
model = model.cuda()
91+
loss_fn = loss_fn.cuda()
92+
93+
return (x_data, y_data), model, loss_fn
94+
95+
96+
def tensor_to_numpy(x: torch.Tensor) -> np.ndarray:
97+
return x.detach().cpu().numpy()

0 commit comments

Comments
 (0)