 from taint import adversarial_attack_blackbox
 from analysis import *
 from train import train_model_and_save
+import torch
+import tensorflow as tf
 
 def attack_model(args, model, test_ds, save_dir, num_data=10):
-    # Path to the pickle file that stores the attacker object
-    pickle_path = os.path.join(save_dir, 'attacker.pkl')
+    # Pull the first batch from test_ds to obtain its labels
+    # (only this one batch is inspected)
+    images, labels = next(iter(test_ds))
 
-    # Check if the adversarial attack has already been performed (if pickle exists)
-    if os.path.exists(pickle_path):
-        # If pickle exists, load the attacker from the file
-        with open(pickle_path, 'rb') as f:
-            attacker = pickle.load(f)
-        print(f"Loaded attacker from {pickle_path}")
-    else:
-        # If pickle does not exist, run the attack and save the attacker
-        print("Running adversarial attack...")
-
-        # First, identify unique outputs in the dataset
-        unique_outputs = set(test_ds.labels)  # assuming `test_ds.labels` contains the true labels
-
-        for output in unique_outputs:
-            # Find the first 10 instances of this output in the dataset
-            instances = [i for i, label in enumerate(test_ds.labels) if label == output][:num_data]
-            # Perform the attack on each of these instances
-            for image_index in instances:
-                adversarial_attack_blackbox(
-                    model, test_ds, image_index=image_index, output_dir=save_dir,
-                    num_iterations=args.iterations, num_particles=args.particles
-                )
-                print(f"Attacked image {image_index} with label {output}")
+    # If the labels are TensorFlow or PyTorch tensors, convert them to plain
+    # class indices (handling the one-hot encoded case)
+    if isinstance(labels, tf.Tensor):
+        if len(labels.shape) > 1:
+            labels = tf.argmax(labels, axis=1)  # one-hot -> class indices
+        labels = labels.numpy()
+    elif isinstance(labels, torch.Tensor):
+        if labels.dim() > 1:
+            labels = torch.argmax(labels, dim=1)  # one-hot -> class indices
+        labels = labels.cpu().numpy()
+
+    # Collect the unique output labels present in the batch
+    unique_outputs = set(labels)
+
+    for output in unique_outputs:
+        # Select up to `num_data` instances that carry the current output label
+        instances = [i for i, label in enumerate(labels) if label == output][:num_data]
+
+        for image_index in instances:
+            # Create a subdirectory for each image_index and its original output label
+            sub_dir = os.path.join(save_dir, f'image_{image_index}_label_{output}')
+            os.makedirs(sub_dir, exist_ok=True)
+
+            # For the current `output`, target every other class
+            for target_output in unique_outputs:
+                if target_output == output:
+                    continue
+
+                target_sub_dir = os.path.join(sub_dir, f'target_{target_output}')
+                os.makedirs(target_sub_dir, exist_ok=True)  # one subdir per target class
+
+                # The pickle filename encodes both the original and the target class
+                target_pickle_filename = f'attacker_{image_index}_{output}_to_{target_output}.pkl'
+                target_pickle_path = os.path.join(target_sub_dir, target_pickle_filename)
+
+                # If this attack was already performed, load the saved attacker and skip it
+                if os.path.exists(target_pickle_path):
+                    with open(target_pickle_path, 'rb') as f:
+                        attacker = pickle.load(f)
+                    print(f"Loaded attacker for image {image_index} "
+                          f"({output} -> {target_output}) from {target_pickle_path}")
+                    continue
+
+                # Perform the adversarial attack targeting `target_output`
+                print(f"Running adversarial attack for image {image_index} "
+                      f"({output} -> {target_output})...")
+                attacker = adversarial_attack_blackbox(
+                    model=model,
+                    dataset=test_ds,
+                    image_index=image_index,
+                    output_dir=target_sub_dir,
+                    num_iterations=args.iterations,
+                    num_particles=args.particles,
+                    target_class=target_output  # target class for the attack
+                )
+
+                # Save the attacker object so the run can be resumed later
+                with open(target_pickle_path, 'wb') as f:
+                    pickle.dump(attacker, f)
+                print(f"Saved attacker for image {image_index} "
+                      f"({output} -> {target_output}) to {target_pickle_path}")
 
 def main():
     # Command-line arguments
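
For reference, the one-hot to class-index conversion used in this change can be sanity-checked in isolation. A minimal sketch, assuming TensorFlow and PyTorch are both installed; the tensors below are illustrative and not taken from test_ds:

    import tensorflow as tf
    import torch

    # One-hot rows encoding classes 1 and 0
    one_hot_tf = tf.constant([[0., 1., 0.], [1., 0., 0.]])
    print(tf.argmax(one_hot_tf, axis=1).numpy())          # [1 0]

    one_hot_pt = torch.tensor([[0., 1., 0.], [1., 0., 0.]])
    print(torch.argmax(one_hot_pt, dim=1).cpu().numpy())  # [1 0]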
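A hypothetical sketch of the command-line wiring that attack_model relies on: only the `iterations` and `particles` attributes are confirmed by the call sites above; the flag names, defaults, and the `model`/`test_ds`/`save_dir` placeholders are assumptions for illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--iterations', type=int, default=100)  # assumed default
    parser.add_argument('--particles', type=int, default=20)    # assumed default
    args = parser.parse_args()

    # attack_model(args, model, test_ds, save_dir, num_data=10)
    # where `model`, `test_ds`, and `save_dir` come from the training/setup code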