
Commit 377923a

Author: Beat Buesser (committed)

Fix LGTM warnings

Signed-off-by: Beat Buesser <[email protected]>
1 parent fd59c73 commit 377923a

File tree: 1 file changed (+9, -9 lines)

art/estimators/pytorch.py (9 additions, 9 deletions)
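All nine changes follow the same pattern: LGTM's py/repeated-import query flags the function-local `import torch` statements, so the commit suppresses each alert with an inline `lgtm` comment and, where `from torch import nn` was used, switches to `import torch` with fully qualified names. A minimal sketch of the suppression pattern (the helper name here is hypothetical, not from the commit):

    def _to_tensor(x):
        # The suppression comment must sit on the flagged line itself; it tells
        # LGTM to ignore the py/repeated-import alert for this import only.
        import torch  # lgtm [py/repeated-import]

        return torch.tensor(x)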
@@ -52,7 +52,7 @@ def __init__(self, device_type: str = "gpu", **kwargs) -> None:
             be divided by the second one.
         :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
         """
-        import torch
+        import torch  # lgtm [py/repeated-import]
 
         preprocessing = kwargs.get("preprocessing")
         if isinstance(preprocessing, tuple):
@@ -150,7 +150,7 @@ def _apply_preprocessing(self, x, y, fit: bool = False, no_grad=True) -> Tuple[A
         :return: Tuple of `x` and `y` after applying the defences and standardisation.
         :rtype: Format as expected by the `model`
         """
-        import torch
+        import torch  # lgtm [py/repeated-import]
 
         from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
         from art.preprocessing.standardisation_mean_std.standardisation_mean_std_pytorch import (
@@ -227,7 +227,7 @@ def _apply_preprocessing_gradient(self, x, gradients, fit=False):
         :return: Gradients after backward pass through preprocessing defences.
         :rtype: Format as expected by the `model`
         """
-        import torch
+        import torch  # lgtm [py/repeated-import]
 
         from art.preprocessing.standardisation_mean_std.standardisation_mean_std import StandardisationMeanStd
         from art.preprocessing.standardisation_mean_std.standardisation_mean_std_pytorch import (
@@ -290,9 +290,9 @@ def _set_layer(self, train: bool, layerinfo: List["torch.nn.modules.Module"]) ->
         :param train: False for evaluation mode.
         :param layerinfo: List of module types.
         """
-        from torch import nn
+        import torch  # lgtm [py/repeated-import]
 
-        assert all([issubclass(l, nn.modules.Module) for l in layerinfo])
+        assert all([issubclass(l, torch.nn.modules.Module) for l in layerinfo])
 
         def set_train(layer, layerinfo=layerinfo):
             "Set layer into training mode if instance of `layerinfo`."
@@ -315,16 +315,16 @@ def set_dropout(self, train: bool) -> None:
 
         :param train: False for evaluation mode.
         """
-        from torch import nn
+        import torch  # lgtm [py/repeated-import]
 
-        self._set_layer(train=train, layerinfo=[nn.modules.dropout._DropoutNd])
+        self._set_layer(train=train, layerinfo=[torch.nn.modules.dropout._DropoutNd])
 
     def set_batchnorm(self, train: bool) -> None:
         """
         Set all batch normalization layers into train or eval mode.
 
         :param train: False for evaluation mode.
         """
-        from torch import nn
+        import torch  # lgtm [py/repeated-import]
 
-        self._set_layer(train=train, layerinfo=[nn.modules.batchnorm._BatchNorm])
+        self._set_layer(train=train, layerinfo=[torch.nn.modules.batchnorm._BatchNorm])
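For context on the `_set_layer` call sites above: the helper asserts that every entry in `layerinfo` is a `torch.nn.modules.Module` subclass and then flips only the matching submodules between train and eval mode, which is how `set_dropout` and `set_batchnorm` toggle those layer types without touching the rest of the network. A minimal standalone sketch of that pattern (not ART's exact implementation; the `set_layer_mode` name and the toy model are illustrative):

    import torch

    def set_layer_mode(model: torch.nn.Module, train: bool, layerinfo: list) -> None:
        # Mirror the assertion from the diff: every entry must be a Module subclass.
        assert all(issubclass(layer, torch.nn.modules.Module) for layer in layerinfo)

        # Walk all submodules and toggle only those whose type matches layerinfo.
        for module in model.modules():
            if isinstance(module, tuple(layerinfo)):
                module.train(mode=train)

    # Example: freeze batch-norm statistics while keeping dropout active.
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.BatchNorm1d(4),
        torch.nn.Dropout(p=0.5),
    )
    set_layer_mode(model, train=False, layerinfo=[torch.nn.modules.batchnorm._BatchNorm])
    set_layer_mode(model, train=True, layerinfo=[torch.nn.modules.dropout._DropoutNd])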
