Skip to content

Commit bfcca90

Browse files
committed
Fix mypy
Signed-off-by: Beat Buesser <[email protected]>
1 parent 61fc9fb commit bfcca90

File tree

1 file changed

+5
-5
lines changed
  • art/estimators/certification/derandomized_smoothing

1 file changed

+5
-5
lines changed

art/estimators/certification/derandomized_smoothing/pytorch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def __init__(
162162
model, pretrained=load_pretrained, drop_tokens=drop_tokens, device_type=device_type
163163
)
164164
if replace_last_layer:
165-
model.head = torch.nn.Linear(model.head.in_features, nb_classes)
165+
model.head = torch.nn.Linear(model.head.in_features, nb_classes) # type: ignore
166166
if isinstance(optimizer, type):
167167
if optimizer_params is not None:
168168
optimizer = optimizer(model.parameters(), **optimizer_params)
@@ -181,9 +181,9 @@ def __init__(
181181
model = timm.create_model(
182182
pretrained_cfg["architecture"], drop_tokens=drop_tokens, device_type=device_type
183183
)
184-
model.load_state_dict(supplied_state_dict)
184+
model.load_state_dict(supplied_state_dict) # type: ignore
185185
if replace_last_layer:
186-
model.head = torch.nn.Linear(model.head.in_features, nb_classes)
186+
model.head = torch.nn.Linear(model.head.in_features, nb_classes) # type: ignore
187187

188188
if optimizer is not None:
189189
if not isinstance(optimizer, torch.optim.Optimizer):
@@ -193,10 +193,10 @@ def __init__(
193193
opt_state_dict = optimizer.state_dict()
194194
if isinstance(optimizer, torch.optim.Adam):
195195
logging.info("Converting Adam Optimiser")
196-
converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
196+
converted_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) # type: ignore
197197
elif isinstance(optimizer, torch.optim.SGD):
198198
logging.info("Converting SGD Optimiser")
199-
converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
199+
converted_optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) # type: ignore
200200
else:
201201
raise ValueError("Optimiser not supported for conversion")
202202
converted_optimizer.load_state_dict(opt_state_dict)

0 commit comments

Comments
 (0)