@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
 import os
 from dataclasses import dataclass
 from typing import Any, Callable, List, Optional, Union

 import smart_open
 import torch
-from accelerate import Accelerator
+from lightning.fabric import Fabric
 from torch import nn

 from gradsflow.models.tracker import Tracker
@@ -31,7 +32,7 @@
 class Base:
     TEST = os.environ.get("GF_CI", "false").lower() == "true"

-    learner: Union[nn.Module, Any]
+    _learner: Union[nn.Module, Any]
     optimizer: torch.optim.Optimizer = None
     loss: Callable = None
     _compiled: bool = False
@@ -43,6 +44,14 @@ def __init__(self):
     def __call__(self, x):
         return self.forward(x)

+    @property
+    def learner(self) -> Union[nn.Module, Any]:
+        return self._learner
+
+    @learner.setter
+    def learner(self, learner):
+        self._learner = learner
+
     @staticmethod
     def _get_loss(loss: Union[str, Callable], loss_config: dict) -> Optional[Callable]:
         loss_fn = None
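For readers skimming the diff: the accessor pair added above is the standard Python property pattern, so existing callers of `model.learner` keep working while the value is stored privately. A minimal self-contained sketch (a toy `Base`, not the full gradsflow class) of how assignment now routes through the setter:

```python
from torch import nn

class Base:
    _learner: nn.Module

    @property
    def learner(self) -> nn.Module:
        # Public read access delegates to the private attribute.
        return self._learner

    @learner.setter
    def learner(self, learner):
        # Assignment is intercepted here; subclasses can hook this.
        self._learner = learner

base = Base()
base.learner = nn.Linear(2, 2)       # routed through the setter
assert base.learner is base._learner  # getter returns the stored module
```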
@@ -101,43 +110,24 @@ class BaseModel(Base):
     def __init__(
         self,
         learner: Union[nn.Module, Any],
-        device: Optional[str] = None,
-        use_accelerate: bool = True,
+        device: Optional[str] = "auto",
+        use_accelerator: bool = True,
         accelerator_config: dict = None,
     ):
         self.accelerator = None
         super().__init__()
-        self._set_accelerator(device, use_accelerate, accelerator_config)
-        self.learner = self.prepare_model(learner)
+        self._set_accelerator(device, use_accelerator, accelerator_config)
+        self._learner = learner

     def _set_accelerator(self, device: Optional[str], use_accelerate: bool, accelerator_config: dict):
         if use_accelerate:
-            self.accelerator = Accelerator(cpu=(device == "cpu"), **accelerator_config)
+            self.accelerator = Fabric(accelerator=device, **accelerator_config)
             self.device = self.accelerator.device
         else:
             self.device = device or default_device()

-    def prepare_model(self, learner: Union[nn.Module, List[nn.Module]]):
-        """Inplace ops for preparing model via HF Accelerator. Automatically sends to device."""
-        if not self.accelerator:
-            learner = learner.to(self.device)
-            return learner
-        if isinstance(learner, (list, tuple)):
-            self.learner = list(map(self.accelerator.prepare_model, learner))
-        elif isinstance(learner, nn.Module):
-            self.learner = self.accelerator.prepare_model(learner)
-        else:
-            raise NotImplementedError(
-                f"prepare_model is not implemented for model of type {type(learner)}! Please implement prepare_model "
-                f"or raise an issue."
-            )
-
-        return self.learner
-
-    def prepare_optimizer(self, optimizer) -> torch.optim.Optimizer:
-        if not self.accelerator:
-            return optimizer
-        return self.accelerator.prepare_optimizer(optimizer)
+    def setup(self, learner: Union[nn.Module, List[nn.Module]], *optimizers):
+        return self.accelerator.setup(learner, *optimizers)

     def backward(self, loss: torch.Tensor):
         """model.backward(loss)"""