
Commit 204010d

refactor Accelerator with Fabric (#192)
* refactor Accelerator with Fabric
* update
* fix
* remove flash
* fixes
* update
* update
* update
1 parent 469caf6 · commit 204010d
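This change swaps HuggingFace Accelerate's prepare-* API for Lightning Fabric's setup API throughout the codebase. A minimal sketch of the pattern the diffs below adopt (the nn.Linear learner and dummy dataloader are illustrative stand-ins, not code from this commit):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric

model = nn.Linear(4, 2)                            # stand-in learner
optimizer = torch.optim.Adam(model.parameters())
loader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,))))

fabric = Fabric(accelerator="auto")                # was: Accelerator(cpu=(device == "cpu"))
model, optimizer = fabric.setup(model, optimizer)  # was: prepare_model + prepare_optimizer
loader = fabric.setup_dataloaders(loader)          # was: prepare_data_loader

for x, y in loader:
    loss = nn.functional.cross_entropy(model(x), y)
    fabric.backward(loss)                          # was: accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()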

File tree

14 files changed: +83 additions, -80 deletions


examples/src/models/hello_world.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 import torchvision
 from timm import create_model
 from torch.utils.data import DataLoader
+from torchmetrics.classification import MulticlassAccuracy
 from torchvision import transforms as T

 from gradsflow import AutoDataset, Model
@@ -55,5 +56,5 @@

 model = Model(cnn)

-model.compile("crossentropyloss", "adam", metrics=["accuracy"])
+model.compile("crossentropyloss", "adam", metrics=[MulticlassAccuracy(autodataset.num_classes)])
 model.fit(autodataset, max_epochs=10, steps_per_epoch=10, callbacks=cbs)
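The example now passes a torchmetrics object instead of the "accuracy" string. A hedged standalone sketch of the metric it constructs (num_classes=10 is an arbitrary illustrative value):

import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=10)
preds = torch.randn(16, 10).softmax(dim=-1)  # fake per-class probabilities
target = torch.randint(0, 10, (16,))
metric.update(preds, target)
print(metric.compute())                      # running accuracy as a tensor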

gradsflow/autotasks/engine/backend.py

Lines changed: 7 additions & 7 deletions
@@ -25,12 +25,12 @@
 from gradsflow.utility.imports import is_installed

 if typing.TYPE_CHECKING:
-    import pytorch_lightning as pl
+    import lightning as L

-if is_installed("pytorch_lightning"):
+if is_installed("lightning-flash"):
     from flash import Task
     from flash import Trainer as FlashTrainer
-    from pytorch_lightning import Trainer as PLTrainer
+    from lightning import Trainer as PLTrainer
 else:
     FlashTrainer = None
     PLTrainer = None
@@ -40,10 +40,10 @@

 class BackendType(Enum):
     # Remove torch
-    pl = "pl"
+    lightning = "lightning"
     gf = "gf"
     torch = "gf"
-    default = "pl"
+    default = "lightning"


 class Backend:
@@ -90,7 +90,7 @@ def _lightning_objective(

         trainer_cls = FlashTrainer if isinstance(model, Task) else PLTrainer

-        trainer: "pl.Trainer" = trainer_cls(
+        trainer: "L.Trainer" = trainer_cls(
             logger=True,
             accelerator="auto",
             devices="auto",
@@ -122,7 +122,7 @@ def optimization_objective(
             trainer_config dict: configurations passed directly to Lightning Trainer.
             gpu Optional[float]: GPU per trial
         """
-        if self.backend_type == BackendType.pl.value:
+        if self.backend_type == BackendType.lightning.value:
             return self._lightning_objective(config, trainer_config=trainer_config, gpu=gpu, finetune=finetune)

         if self.backend_type in (BackendType.gf.value,):

gradsflow/data/autodata.py

Lines changed: 4 additions & 4 deletions
@@ -15,7 +15,7 @@
 import warnings
 from typing import Callable, Dict, Optional, Union

-from accelerate import Accelerator
+from lightning.fabric import Fabric
 from torch.utils.data import DataLoader, Dataset

 from gradsflow.data.base import BaseAutoDataset
@@ -131,13 +131,13 @@ def device_setup_status(self, value: bool = True):
         logger.debug("setting device setup=True")
         self.meta["device_setup_status"] = value

-    def prepare_data(self, accelerator: Accelerator) -> None:
+    def setup_data(self, accelerator: Fabric) -> None:
         if accelerator is None:
             warnings.warn("Accelerator is None, skipped data preparation!")
             return
-        self._train_dataloader = accelerator.prepare_data_loader(self._train_dataloader)
+        self._train_dataloader = accelerator.setup_dataloaders(self._train_dataloader)
         if self._val_dataloader:
-            self._val_dataloader = accelerator.prepare_data_loader(self._val_dataloader)
+            self._val_dataloader = accelerator.setup_dataloaders(self._val_dataloader)
         self.device_setup_status = True
         self.device = accelerator.device
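setup_data now hands loaders to Fabric.setup_dataloaders, which returns wrapped loaders whose batches are moved to the target device on iteration. A small sketch under that assumption (the dummy TensorDataset is illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto")
loader = DataLoader(TensorDataset(torch.randn(4, 3)), batch_size=2)
loader = fabric.setup_dataloaders(loader)

for (batch,) in loader:
    assert batch.device == fabric.device  # batches arrive on fabric's device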

gradsflow/models/base.py

Lines changed: 18 additions & 28 deletions
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from dataclasses import dataclass
 from typing import Any, Callable, List, Optional, Union

 import smart_open
 import torch
-from accelerate import Accelerator
+from lightning.fabric import Fabric
 from torch import nn

 from gradsflow.models.tracker import Tracker
@@ -31,7 +32,7 @@
 class Base:
     TEST = os.environ.get("GF_CI", "false").lower() == "true"

-    learner: Union[nn.Module, Any]
+    _learner: Union[nn.Module, Any]
     optimizer: torch.optim.Optimizer = None
     loss: Callable = None
     _compiled: bool = False
@@ -43,6 +44,14 @@ def __init__(self):
     def __call__(self, x):
         return self.forward(x)

+    @property
+    def learner(self) -> Union[nn.Module, Any]:
+        return self._learner
+
+    @learner.setter
+    def learner(self, learner):
+        self._learner = learner
+
     @staticmethod
     def _get_loss(loss: Union[str, Callable], loss_config: dict) -> Optional[Callable]:
         loss_fn = None
@@ -101,43 +110,24 @@ class BaseModel(Base):
     def __init__(
         self,
         learner: Union[nn.Module, Any],
-        device: Optional[str] = None,
-        use_accelerate: bool = True,
+        device: Optional[str] = "auto",
+        use_accelerator: bool = True,
         accelerator_config: dict = None,
     ):
         self.accelerator = None
         super().__init__()
-        self._set_accelerator(device, use_accelerate, accelerator_config)
-        self.learner = self.prepare_model(learner)
+        self._set_accelerator(device, use_accelerator, accelerator_config)
+        self._learner = learner

     def _set_accelerator(self, device: Optional[str], use_accelerate: bool, accelerator_config: dict):
         if use_accelerate:
-            self.accelerator = Accelerator(cpu=(device == "cpu"), **accelerator_config)
+            self.accelerator = Fabric(accelerator=device, **accelerator_config)
             self.device = self.accelerator.device
         else:
             self.device = device or default_device()

-    def prepare_model(self, learner: Union[nn.Module, List[nn.Module]]):
-        """Inplace ops for preparing model via HF Accelerator. Automatically sends to device."""
-        if not self.accelerator:
-            learner = learner.to(self.device)
-            return learner
-        if isinstance(learner, (list, tuple)):
-            self.learner = list(map(self.accelerator.prepare_model, learner))
-        elif isinstance(learner, nn.Module):
-            self.learner = self.accelerator.prepare_model(learner)
-        else:
-            raise NotImplementedError(
-                f"prepare_model is not implemented for model of type {type(learner)}! Please implement prepare_model "
-                f"or raise an issue."
-            )
-
-        return self.learner
-
-    def prepare_optimizer(self, optimizer) -> torch.optim.Optimizer:
-        if not self.accelerator:
-            return optimizer
-        return self.accelerator.prepare_optimizer(optimizer)
+    def setup(self, learner: Union[nn.Module, List[nn.Module]], *optimizers):
+        return self.accelerator.setup(learner, *optimizers)

     def backward(self, loss: torch.Tensor):
         """model.backward(loss)"""

gradsflow/models/model.py

Lines changed: 5 additions & 5 deletions
@@ -46,7 +46,7 @@ class Model(BaseModel, DataMixin):

     Args:
         learner: Trainable model
-        accelerator_config: HuggingFace Accelerator config
+        accelerator_config: Accelerator config
     """

     TEST = os.environ.get("GF_CI", "false").lower() == "true"
@@ -56,14 +56,14 @@ def __init__(
         self,
         learner: Union[nn.Module, Any],
         device: Optional[str] = None,
-        use_accelerate: bool = True,
+        use_accelerator: bool = True,
         accelerator_config: dict = None,
     ):
         accelerator_config = accelerator_config or {}
         super().__init__(
             learner=learner,
             device=device,
-            use_accelerate=use_accelerate,
+            use_accelerator=use_accelerator,
             accelerator_config=accelerator_config,
         )

@@ -119,9 +119,9 @@ def compile(
         if optimizer:
             optimizer_fn = self._get_optimizer(optimizer)
             optimizer = optimizer_fn(self.learner.parameters(), lr=learning_rate, **optimizer_config)
-            self.optimizer = self.prepare_optimizer(optimizer)
         if loss:
             self.loss = self._get_loss(loss, loss_config)
+        self.learner, self.optimizer = self.setup(self._learner, optimizer)
         self.metrics.compile_metrics(*listify(metrics))
         self._compiled = True

@@ -244,7 +244,7 @@ def fit(
         """
         self.assert_compiled()
         self.autodataset = autodataset
-        self.autodataset.prepare_data(self.accelerator)
+        self.autodataset.setup_data(self.accelerator)

         if not resume:
             self.tracker.reset()
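Since compile() now runs setup() itself after building the optimizer, end-user code keeps the same shape. A hedged sketch mirroring the hello_world example above (the Sequential learner and num_classes=10 are illustrative):

from torch import nn
from torchmetrics.classification import MulticlassAccuracy
from gradsflow import Model

cnn = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # illustrative learner
model = Model(cnn)
model.compile("crossentropyloss", "adam", metrics=[MulticlassAccuracy(num_classes=10)])
# model.fit(autodataset, max_epochs=10)  # with an AutoDataset as in examples/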

gradsflow/utility/common.py

Lines changed: 2 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 import dataclasses
 import inspect
+import logging
 import os
 import re
 import sys
@@ -120,7 +121,7 @@ def to_item(data: Any) -> Union[int, float, str, np.ndarray, Dict]:
     data = data.detach()
     data = data.cpu().numpy()

-    warnings.warn("to_item didn't convert any value.")
+    logging.info("to_item didn't convert any value.")
     return data


pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -11,3 +11,7 @@ profile = "black"

 [tool.black]
 line_length = 120
+
+
+[tool.pytest.ini_options]
+norecursedirs = ["tests/autotasks", "tests/tuner"]

setup.cfg

Lines changed: 4 additions & 4 deletions
@@ -29,16 +29,16 @@ python_requires = >=3.8
 install_requires =
     torch >=1.13.1
     torchvision
-    ray[default,tune] >=1.8.0
+    ray[default,tune] >=2.2.0
     timm>=0.6.12
     rich>=13.3.1
-    accelerate >=0.5.0
     smart_open >=5.1,<=5.2.1
     torchmetrics >=0.11.1
+    lightning >=1.9.2

 [options.extras_require]
-dev = lightning-flash[image,text] >=0.5.1; codecarbon >=1.2.0; comet_ml; wandb; tensorboard
-test = pytest; coverage; pytest-sugar
+dev = codecarbon >=1.2.0; wandb; tensorboard
+test = pytest; coverage; pytest-sugar; pytest-randomly

 [options.packages.find] #optional
 exclude=tests, docs, examples

tests/__main__.py

Lines changed: 9 additions & 7 deletions
@@ -12,15 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import urllib.request
+import zipfile
 from pathlib import Path

-from flash.core.data.utils import download_data
+cwd = Path.cwd()
+(Path.cwd() / "data").mkdir(exist_ok=True)

-cwd = str(Path.cwd())
-
-download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", f"{cwd}/data")
-
-download_data(
+urllib.request.urlretrieve(
     "https://github.com/gradsflow/test-data/archive/refs/tags/cat-dog-v0.zip",
-    f"{cwd}/data",
+    f"{cwd}/data/test-cat-dog-v0.zip",
 )
+
+with zipfile.ZipFile(f"{cwd}/data/test-cat-dog-v0.zip", "r") as zip_ref:
+    zip_ref.extractall(f"{cwd}/data/")

tests/autotasks/test_core_automodel.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def test_objective(mock_pl_trainer, mock_fl_trainer):
     model = AutoModel(
         datamodule,
         optimization_metric=optimization_metric,
-        backend=BackendType.pl.value,
+        backend=BackendType.lightning.value,
     )

     model.backend.model_builder = MagicMock()
