Skip to content

Commit 9ea2a91

Browse files
committed
899 updated MembershipInferenceBlackBox#infer to have CUDA support
Signed-off-by: Evan Sakmar <[email protected]>
1 parent cc60b1b commit 9ea2a91

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

art/attacks/inference/membership_inference/black_box.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,13 +198,7 @@ def fit(self, x: np.ndarray, y: np.ndarray, test_x: np.ndarray, test_y: np.ndarr
198198
import torch.nn as nn # lgtm [py/repeated-import]
199199
import torch.optim as optim # lgtm [py/repeated-import]
200200
from torch.utils.data import DataLoader # lgtm [py/repeated-import]
201-
202-
use_cuda = torch.cuda.is_available()
203-
204-
def to_cuda(x):
205-
if use_cuda:
206-
x = x.cuda()
207-
return x
201+
from art.utils import to_cuda
208202

209203
loss_fn = nn.BCELoss()
210204
optimizer = optim.Adam(self.attack_model.parameters(), lr=self.learning_rate)
@@ -261,14 +255,18 @@ def infer(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.n
261255
if self.default_model and self.attack_model_type == "nn":
262256
import torch # lgtm [py/repeated-import]
263257
from torch.utils.data import DataLoader # lgtm [py/repeated-import]
258+
from art.utils import to_cuda, from_cuda
264259

265260
self.attack_model.eval()
266261
inferred = None
267262
test_set = self._get_attack_dataset(f_1=features, f_2=y)
268263
test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=True, num_workers=0)
269264
for input1, input2, _ in test_loader:
265+
input1, input2 = to_cuda(input1), to_cuda(input2)
270266
outputs = self.attack_model(input1, input2)
271267
predicted = torch.round(outputs)
268+
predicted = from_cuda(predicted)
269+
272270
if inferred is None:
273271
inferred = predicted.detach().numpy()
274272
else:

art/utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from scipy.special import gammainc
3737
import six
3838
from tqdm.auto import tqdm
39+
from torch import Tensor
3940

4041
from art import config
4142

@@ -1235,3 +1236,34 @@ def pad_sequence_input(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
12351236
x_padded[i, : len(x_i)] = x_i
12361237
x_mask[i, : len(x_i)] = 1
12371238
return x_padded, x_mask
1239+
1240+
1241+
# -------------------------------------------------------------------------------------------------------- CUDA SUPPORT
1242+
1243+
1244+
def to_cuda(x: Tensor) -> Tensor:
1245+
"""
1246+
Move the tensor from the CPU to the GPU if a GPU is available.
1247+
1248+
:param x: CPU Tensor to move to GPU if available.
1249+
:return: The CPU Tensor moved to a GPU Tensor.
1250+
"""
1251+
from torch.cuda import is_available
1252+
use_cuda = is_available()
1253+
if use_cuda:
1254+
x = x.cuda()
1255+
return x
1256+
1257+
1258+
def from_cuda(x: Tensor) -> Tensor:
1259+
"""
1260+
Move the tensor from the GPU to the CPU if a GPU is available.
1261+
1262+
:param x: GPU Tensor to move to CPU if available.
1263+
:return: The GPU Tensor moved to a CPU Tensor.
1264+
"""
1265+
from torch.cuda import is_available
1266+
use_cuda = is_available()
1267+
if use_cuda:
1268+
x = x.cpu()
1269+
return x

0 commit comments

Comments
 (0)