Skip to content

Commit 01f8531

Browse files
authored
Refactor BoringFabric in tests (#19364)
1 parent 28b3806 commit 01f8531

File tree

7 files changed

+151
-141
lines changed

7 files changed

+151
-141
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Iterator
2+
3+
import torch
4+
from torch import Tensor
5+
from torch.utils.data import Dataset, IterableDataset
6+
7+
8+
class RandomDataset(Dataset):
9+
def __init__(self, size: int, length: int) -> None:
10+
self.len = length
11+
self.data = torch.randn(length, size)
12+
13+
def __getitem__(self, index: int) -> Tensor:
14+
return self.data[index]
15+
16+
def __len__(self) -> int:
17+
return self.len
18+
19+
20+
class RandomIterableDataset(IterableDataset):
21+
def __init__(self, size: int, count: int) -> None:
22+
self.count = count
23+
self.size = size
24+
25+
def __iter__(self) -> Iterator[Tensor]:
26+
for _ in range(self.count):
27+
yield torch.randn(self.size)

tests/tests_fabric/helpers/models.py

Lines changed: 0 additions & 76 deletions
This file was deleted.

tests/tests_fabric/strategies/test_deepspeed_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from lightning.fabric.strategies import DeepSpeedStrategy
2626
from torch.utils.data import DataLoader
2727

28-
from tests_fabric.helpers.models import RandomDataset, RandomIterableDataset
28+
from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset
2929
from tests_fabric.helpers.runif import RunIf
3030
from tests_fabric.test_fabric import BoringModel
3131

0 commit comments

Comments
 (0)