@@ -162,7 +162,7 @@ def __init__(
                 model, pretrained=load_pretrained, drop_tokens=drop_tokens, device_type=device_type
             )
             if replace_last_layer:
-                model.head = torch.nn.Linear(model.head.in_features, nb_classes)
+                model.head = torch.nn.Linear(model.head.in_features, nb_classes)  # type: ignore
             if isinstance(optimizer, type):
                 if optimizer_params is not None:
                     optimizer = optimizer(model.parameters(), **optimizer_params)
@@ -181,9 +181,9 @@ def __init__(
             model = timm.create_model(
                 pretrained_cfg["architecture"], drop_tokens=drop_tokens, device_type=device_type
             )
-            model.load_state_dict(supplied_state_dict)
+            model.load_state_dict(supplied_state_dict)  # type: ignore
             if replace_last_layer:
-                model.head = torch.nn.Linear(model.head.in_features, nb_classes)
+                model.head = torch.nn.Linear(model.head.in_features, nb_classes)  # type: ignore

        if optimizer is not None:
            if not isinstance(optimizer, torch.optim.Optimizer):
@@ -193,10 +193,10 @@ def __init__(
            opt_state_dict = optimizer.state_dict()
            if isinstance(optimizer, torch.optim.Adam):
                logging.info("Converting Adam Optimiser")
-                converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+                converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # type: ignore
            elif isinstance(optimizer, torch.optim.SGD):
                logging.info("Converting SGD Optimiser")
-                converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
+                converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)  # type: ignore
            else:
                raise ValueError("Optimiser not supported for conversion")
            converted_optimizer.load_state_dict(opt_state_dict)
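
For context, the last hunk rebuilds an optimizer of the same class around the new model's parameters and then restores the caller's settings and state via load_state_dict; the lr=1e-4 is only a throwaway value that the loaded state overwrites. A minimal sketch of this pattern, not part of the commit and assuming a standard timm model (the architecture name and learning rates below are placeholders):

import timm
import torch

# Original model and the optimizer the caller supplied for it.
original_model = timm.create_model("deit_tiny_patch16_224", pretrained=False)
supplied_optimizer = torch.optim.Adam(original_model.parameters(), lr=3e-4)
opt_state_dict = supplied_optimizer.state_dict()

# Re-create the model and build a fresh optimizer around its parameters;
# the dummy lr is overwritten when the saved param groups are loaded back in.
new_model = timm.create_model("deit_tiny_patch16_224", pretrained=False)
converted_optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-4)
converted_optimizer.load_state_dict(opt_state_dict)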