Skip to content

Commit 40c7100

Browse files
authored
Test fabric (#193)
* refactor apis * update * fix * update * update * fix test
1 parent b089b1f commit 40c7100

File tree

7 files changed

+107
-41
lines changed

7 files changed

+107
-41
lines changed

examples/src/models/hello_world.py

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,50 +11,94 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
# Source code inspired from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
15+
import matplotlib.pyplot as plt
16+
import numpy as np
17+
import torch
18+
import torch.nn.functional as F
19+
import torch.optim as optim
1520
import torchvision
16-
from timm import create_model
21+
import torchvision.transforms as transforms
22+
from torch import nn
1723
from torch.utils.data import DataLoader
1824
from torchmetrics.classification import MulticlassAccuracy
19-
from torchvision import transforms as T
2025

2126
from gradsflow import AutoDataset, Model
22-
from gradsflow.callbacks import (
23-
CometCallback,
24-
CSVLogger,
25-
EmissionTrackerCallback,
26-
ModelCheckpoint,
27-
WandbCallback,
28-
)
29-
from gradsflow.data.common import random_split_dataset
30-
31-
# Replace dataloaders with your custom dataset and you are all set to train your model
27+
from gradsflow.callbacks import CSVLogger, ModelCheckpoint
28+
29+
# Replace dataloaders with your custom dataset, and you are all set to train your model
3230
image_size = (64, 64)
3331
batch_size = 4
3432

35-
to_rgb = lambda x: x.convert("RGB")
33+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
34+
35+
trainset = torchvision.datasets.CIFAR10(root="~/data", train=True, download=True, transform=transform)
36+
train_dl = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
3637

37-
augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()])
38-
data = torchvision.datasets.CIFAR10("~/data", download=True, transform=augs)
39-
train_data, val_data = random_split_dataset(data, 0.99)
40-
train_dl = DataLoader(train_data, batch_size=batch_size)
41-
val_dl = DataLoader(val_data, batch_size=batch_size)
42-
num_classes = len(data.classes)
38+
testset = torchvision.datasets.CIFAR10(root="~/data", train=False, download=True, transform=transform)
39+
val_dl = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
40+
num_classes = len(trainset.classes)
4341
cbs = [
4442
CSVLogger(
4543
verbose=True,
4644
),
4745
ModelCheckpoint(),
48-
EmissionTrackerCallback(),
46+
# EmissionTrackerCallback(),
4947
# CometCallback(offline=True),
50-
WandbCallback(),
48+
# WandbCallback(),
5149
]
5250

51+
52+
def imshow(img):
53+
img = img / 2 + 0.5 # unnormalize
54+
npimg = img.numpy()
55+
plt.imshow(np.transpose(npimg, (1, 2, 0)))
56+
plt.show()
57+
58+
59+
class Net(nn.Module):
60+
def __init__(self):
61+
super().__init__()
62+
self.conv1 = nn.Conv2d(3, 6, 5)
63+
self.pool = nn.MaxPool2d(2, 2)
64+
self.conv2 = nn.Conv2d(6, 16, 5)
65+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
66+
self.fc2 = nn.Linear(120, 84)
67+
self.fc3 = nn.Linear(84, 10)
68+
69+
def forward(self, x):
70+
x = self.pool(F.relu(self.conv1(x)))
71+
x = self.pool(F.relu(self.conv2(x)))
72+
x = torch.flatten(x, 1) # flatten all dimensions except batch
73+
x = F.relu(self.fc1(x))
74+
x = F.relu(self.fc2(x))
75+
x = self.fc3(x)
76+
return x
77+
78+
5379
if __name__ == "__main__":
5480
autodataset = AutoDataset(train_dl, val_dl, num_classes=num_classes)
55-
cnn = create_model("resnet18", pretrained=False, num_classes=num_classes)
81+
net = Net()
82+
model = Model(net)
83+
criterion = nn.CrossEntropyLoss()
84+
85+
model.compile(
86+
criterion,
87+
optim.SGD,
88+
optimizer_config={"momentum": 0.9},
89+
learning_rate=0.001,
90+
metrics=[MulticlassAccuracy(autodataset.num_classes)],
91+
)
92+
model.fit(autodataset, max_epochs=2, callbacks=cbs)
93+
94+
dataiter = iter(val_dl)
95+
images, labels = next(dataiter)
96+
97+
# print images
98+
# imshow(torchvision.utils.make_grid(images))
99+
print("GroundTruth: ", " ".join(f"{trainset.classes[labels[j]]:5s}" for j in range(4)))
56100

57-
model = Model(cnn)
101+
outputs = net(images)
102+
_, predicted = torch.max(outputs, 1)
58103

59-
model.compile("crossentropyloss", "adam", metrics=[MulticlassAccuracy(autodataset.num_classes)])
60-
model.fit(autodataset, max_epochs=10, steps_per_epoch=10, callbacks=cbs)
104+
print("Predicted: ", " ".join(f"{trainset.classes[predicted[j]]:5s}" for j in range(4)))

gradsflow/data/autodata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def _fetch(self, data, device_mapper: Optional[Callable] = None):
151151
if self.device_setup_status:
152152
return data
153153
if device_mapper:
154-
data = map(device_mapper, data, self._default_device)
154+
data = map(device_mapper, data)
155155
return data
156156

157157
@property

gradsflow/data/mixins.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515

1616
import torch
1717

18+
from gradsflow.utility import default_device
19+
1820

1921
class DataMixin:
2022
INPUT_KEY = 0 # other common value - inputs, images, text
2123
OUTPUT_KEY = 1 # other common values - target, ground
24+
device = default_device()
2225

2326
def fetch_inputs(self, data: Union[List, Dict]):
2427
return data[self.INPUT_KEY]

gradsflow/models/base.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import logging
1514
import os
1615
from dataclasses import dataclass
1716
from typing import Any, Callable, List, Optional, Union
1817

18+
import lightning as L
1919
import smart_open
2020
import torch
2121
from lightning.fabric import Fabric
@@ -111,30 +111,39 @@ def __init__(
111111
self,
112112
learner: Union[nn.Module, Any],
113113
device: Optional[str] = "auto",
114+
strategy: Optional[str] = None,
115+
precision: Any = 32,
116+
num_nodes: int = 1,
114117
use_accelerator: bool = True,
115118
accelerator_config: dict = None,
116119
):
117-
self.accelerator = None
120+
self._accelerator: L.Fabric = None
118121
super().__init__()
119-
self._set_accelerator(device, use_accelerator, accelerator_config)
122+
self._set_accelerator(device, strategy, precision, num_nodes, use_accelerator, accelerator_config)
120123
self._learner = learner
121124

122-
def _set_accelerator(self, device: Optional[str], use_accelerate: bool, accelerator_config: dict):
125+
def _set_accelerator(
126+
self, device: Optional[str], strategy, precision, num_nodes, use_accelerate: bool, accelerator_config: dict
127+
):
123128
if use_accelerate:
124-
self.accelerator = Fabric(accelerator=device, **accelerator_config)
125-
self.device = self.accelerator.device
129+
self._accelerator = Fabric(
130+
accelerator=device, strategy=strategy, precision=precision, num_nodes=num_nodes, **accelerator_config
131+
)
132+
self.device = self._accelerator.device
126133
else:
127134
self.device = device or default_device()
128135

129136
def setup(self, learner: Union[nn.Module, List[nn.Module]], *optimizers):
130-
return self.accelerator.setup(learner, *optimizers)
137+
if not self._accelerator:
138+
return learner, *optimizers
139+
return self._accelerator.setup(learner, *optimizers)
131140

132141
def backward(self, loss: torch.Tensor):
133142
"""model.backward(loss)"""
134-
if not self.accelerator:
143+
if not self._accelerator:
135144
loss.backward()
136145
else:
137-
self.accelerator.backward(loss)
146+
self._accelerator.backward(loss)
138147

139148
def eval(self):
140149
"""Set learner to eval mode for validation"""

gradsflow/models/model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class Model(BaseModel, DataMixin):
4646
4747
Args:
4848
learner: Trainable model
49+
device: auto | cpu | gpu | mps
50+
precision: Numerical precision value, could be 32 | 16 | "b16"
51+
strategy: Strategy for distributed training (ddp | ddp_spawn | deepspeed | fsdp)
4952
accelerator_config: Accelerator config
5053
"""
5154

@@ -56,13 +59,19 @@ def __init__(
5659
self,
5760
learner: Union[nn.Module, Any],
5861
device: Optional[str] = None,
62+
strategy: Optional[str] = None,
63+
precision: Any = 32,
64+
num_nodes: int = 1,
5965
use_accelerator: bool = True,
6066
accelerator_config: dict = None,
6167
):
6268
accelerator_config = accelerator_config or {}
6369
super().__init__(
6470
learner=learner,
6571
device=device,
72+
strategy=strategy,
73+
precision=precision,
74+
num_nodes=num_nodes,
6675
use_accelerator=use_accelerator,
6776
accelerator_config=accelerator_config,
6877
)
@@ -121,6 +130,7 @@ def compile(
121130
optimizer = optimizer_fn(self.learner.parameters(), lr=learning_rate, **optimizer_config)
122131
if loss:
123132
self.loss = self._get_loss(loss, loss_config)
133+
124134
self.learner, self.optimizer = self.setup(self._learner, optimizer)
125135
self.metrics.compile_metrics(*listify(metrics))
126136
self._compiled = True
@@ -227,7 +237,7 @@ def fit(
227237
```python
228238
autodataset = AutoDataset(train_dataloader, val_dataloader)
229239
model = Model(cnn)
230-
model.compile("crossentropyloss", "adam", learning_rate=1e-3, metrics="accuracy")
240+
model.compile("crossentropyloss", "adam", learning_rate=1e-3)
231241
model.fit(autodataset)
232242
```
233243
Args:
@@ -244,7 +254,7 @@ def fit(
244254
"""
245255
self.assert_compiled()
246256
self.autodataset = autodataset
247-
self.autodataset.setup_data(self.accelerator)
257+
self.autodataset.setup_data(self._accelerator)
248258

249259
if not resume:
250260
self.tracker.reset()

gradsflow/tuner/automodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import ray
2121
from ray import tune
22-
from ray.tune.sample import Domain
22+
from ray.tune.search.sample import Domain
2323
from torch import nn
2424

2525
from gradsflow.data import AutoDataset

tests/models/test_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def compute_accuracy(*_, **__):
8989

9090

9191
def test_set_accelerator(resnet18):
92-
model = Model(resnet18, accelerator_config={"precision": 16})
92+
model = Model(resnet18, precision=16)
9393
model.compile()
94-
assert model.accelerator
94+
assert model._accelerator
9595

9696

9797
def test_save_model(tmp_path, resnet18, cnn_model):

0 commit comments

Comments
 (0)