@@ -162,7 +162,7 @@ def __init__(
                 model, pretrained=load_pretrained, drop_tokens=drop_tokens, device_type=device_type
             )
             if replace_last_layer:
-                model.head = torch.nn.Linear(model.head.in_features, nb_classes)
+                model.head = torch.nn.Linear(model.head.in_features, nb_classes)  # type: ignore
             if isinstance(optimizer, type):
                 if optimizer_params is not None:
                     optimizer = optimizer(model.parameters(), **optimizer_params)
@@ -181,9 +181,9 @@ def __init__(
             model = timm.create_model(
                 pretrained_cfg["architecture"], drop_tokens=drop_tokens, device_type=device_type
             )
-            model.load_state_dict(supplied_state_dict)
+            model.load_state_dict(supplied_state_dict)  # type: ignore
             if replace_last_layer:
-                model.head = torch.nn.Linear(model.head.in_features, nb_classes)
+                model.head = torch.nn.Linear(model.head.in_features, nb_classes)  # type: ignore
 
         if optimizer is not None:
             if not isinstance(optimizer, torch.optim.Optimizer):
@@ -193,10 +193,10 @@ def __init__(
             opt_state_dict = optimizer.state_dict()
             if isinstance(optimizer, torch.optim.Adam):
                 logging.info("Converting Adam Optimiser")
-                converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+                converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # type: ignore
             elif isinstance(optimizer, torch.optim.SGD):
                 logging.info("Converting SGD Optimiser")
-                converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
+                converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)  # type: ignore
             else:
                 raise ValueError("Optimiser not supported for conversion")
             converted_optimizer.load_state_dict(opt_state_dict)
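
For context on the head-replacement ignores: below is a minimal sketch, not part of this commit, of the kind of mypy error being silenced; the model name is only an illustrative choice. In the torch type stubs, attribute access on nn.Module resolves through __getattr__ to Union[Tensor, Module], so mypy cannot prove that model.head is a Linear layer exposing .in_features.

import timm  # assumes timm is installed
import torch

# Illustrative architecture; any timm ViT/DeiT model behaves the same here.
model = timm.create_model("deit_tiny_patch16_224", pretrained=False)

# Without the ignore, mypy reports something like:
#   Item "Tensor" of "Union[Tensor, Module]" has no attribute "in_features"
model.head = torch.nn.Linear(model.head.in_features, 10)  # type: ignore

# A narrower alternative to a blanket ignore: let mypy narrow the type first.
assert isinstance(model.head, torch.nn.Linear)
model.head = torch.nn.Linear(model.head.in_features, 10)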