|
| 1 | +""" |
| 2 | +================================================================================ |
| 3 | +WGAN-GP Conditional Generator for DDoS Flow Features (CICDDoS2019) |
| 4 | +================================================================================ |
| 5 | +Author : Breno Farias da Silva |
| 6 | +Created : 2025-11-21 |
| 7 | +Description : |
| 8 | + Implements a DRCGAN-like conditional Wasserstein GAN with Gradient Penalty |
| 9 | + for tabular flow features from CICDDoS2019 dataset. This module enables |
| 10 | + generating synthetic network flow data conditioned on attack labels. |
| 11 | +
|
| 12 | + Key features include: |
| 13 | + - CSV loader with automatic scaling and label encoding |
| 14 | + - Conditional generator with residual blocks (DRC-style architecture) |
| 15 | + - MLP discriminator (critic) for Wasserstein loss |
| 16 | + - WGAN-GP training loop with gradient penalty for stability |
| 17 | + - Checkpoint saving and synthetic sample generation to CSV |
| 18 | + - Support for multi-class conditional generation |
| 19 | +
|
| 20 | +Usage: |
| 21 | + 1. Prepare a CSV file with network flow features and labels. |
| 22 | + 2. Train the model using the train mode: |
| 23 | + $ python wgangp.py --mode train --csv_path data.csv --epochs 60 |
| 24 | + 3. Generate synthetic samples using a trained checkpoint: |
| 25 | + $ python wgangp.py --mode gen --checkpoint outputs/generator_epoch60.pt --n_samples 1000 |
| 26 | +
|
| 27 | +Outputs: |
| 28 | + - outputs/generator_epoch*.pt — Saved generator checkpoints with metadata |
| 29 | + - outputs/discriminator_epoch*.pt — Saved discriminator checkpoints |
| 30 | + - generated.csv — Generated synthetic flow samples (via --mode gen) |
| 31 | +
|
| 32 | +TODOs: |
| 33 | + - Implement learning rate scheduling for better convergence |
| 34 | + - Add support for different activation functions |
| 35 | + - Extend feature importance analysis for generated data |
| 36 | + - Add data quality metrics (statistical distance, mode coverage) |
| 37 | + - Implement multi-GPU training support |
| 38 | +
|
| 39 | +Dependencies: |
| 40 | + - Python >= 3.9 |
| 41 | + - torch >= 1.9.0 |
| 42 | + - numpy |
| 43 | + - pandas |
| 44 | + - scikit-learn |
| 45 | +
|
| 46 | +Assumptions & Notes: |
| 47 | + - CSV should contain feature columns and a label column |
| 48 | + - Features are automatically scaled using StandardScaler |
| 49 | + - Labels are encoded via LabelEncoder (categorical to integer) |
| 50 | + - Output features are inverse-transformed to original scale |
| 51 | + - CUDA is used if available; use --force_cpu to disable |
| 52 | +
|
| 53 | +================================================================================ |
| 54 | +""" |
| 55 | + |
| 56 | +import argparse # For CLI argument parsing |
| 57 | +import atexit # For playing a sound when the program finishes |
| 58 | +import numpy as np # Numerical operations |
| 59 | +import os # For running a command in the terminal |
| 60 | +import pandas as pd # For CSV handling |
| 61 | +import platform # For getting the operating system name |
| 62 | +import random # For reproducibility |
| 63 | +import torch # PyTorch core |
| 64 | +import torch.nn as nn # Neural network modules |
| 65 | +from colorama import Style # For coloring the terminal |
| 66 | +from sklearn.preprocessing import StandardScaler, LabelEncoder # For data preprocessing |
| 67 | +from torch import autograd # For gradient penalty |
| 68 | +from torch.utils.data import Dataset, DataLoader # Dataset and DataLoader |
| 69 | +from typing import Any, List, Optional # For Any type hint |
| 70 | + |
| 71 | +# Macros: |
| 72 | +class BackgroundColors: # Colors for the terminal |
| 73 | + CYAN = "\033[96m" # Cyan |
| 74 | + GREEN = "\033[92m" # Green |
| 75 | + YELLOW = "\033[93m" # Yellow |
| 76 | + RED = "\033[91m" # Red |
| 77 | + BOLD = "\033[1m" # Bold |
| 78 | + UNDERLINE = "\033[4m" # Underline |
| 79 | + CLEAR_TERMINAL = "\033[H\033[J" # Clear the terminal |
| 80 | + |
| 81 | +# Execution Constants: |
| 82 | +VERBOSE = False # Set to True to output verbose messages |
| 83 | + |
| 84 | +# Sound Constants: |
| 85 | +SOUND_COMMANDS = {"Darwin": "afplay", "Linux": "aplay", "Windows": "start"} # The commands to play a sound for each operating system |
| 86 | +SOUND_FILE = "./.assets/Sounds/NotificationSound.wav" # The path to the sound file |
| 87 | + |
| 88 | +# RUN_FUNCTIONS: |
| 89 | +RUN_FUNCTIONS = { |
| 90 | + "Play Sound": True, # Set to True to play a sound when the program finishes |
| 91 | +} |
| 92 | + |
| 93 | +# Functions Definitions: |
| 94 | + |
| 95 | +def verbose_output(true_string="", false_string=""): |
| 96 | + """ |
| 97 | + Outputs a message if the VERBOSE constant is set to True. |
| 98 | +
|
| 99 | + :param true_string: The string to be outputted if the VERBOSE constant is set to True. |
| 100 | + :param false_string: The string to be outputted if the VERBOSE constant is set to False. |
| 101 | + :return: None |
| 102 | + """ |
| 103 | + |
| 104 | + if VERBOSE and true_string != "": # If the VERBOSE constant is set to True and the true_string is set |
| 105 | + print(true_string) # Output the true statement string |
| 106 | + elif false_string != "": # If the false_string is set |
| 107 | + print(false_string) # Output the false statement string |
| 108 | + |
| 109 | +def verify_filepath_exists(filepath): |
| 110 | + """ |
| 111 | + Verify if a file or folder exists at the specified path. |
| 112 | +
|
| 113 | + :param filepath: Path to the file or folder |
| 114 | + :return: True if the file or folder exists, False otherwise |
| 115 | + """ |
| 116 | + |
| 117 | + verbose_output(f"{BackgroundColors.GREEN}Verifying if the file or folder exists at the path: {BackgroundColors.CYAN}{filepath}{Style.RESET_ALL}") # Output the verbose message |
| 118 | + |
| 119 | + return os.path.exists(filepath) # Return True if the file or folder exists, False otherwise |
| 120 | + |
| 121 | +def play_sound(): |
| 122 | + """ |
| 123 | + Plays a sound when the program finishes and skips if the operating system is Windows. |
| 124 | +
|
| 125 | + :param: None |
| 126 | + :return: None |
| 127 | + """ |
| 128 | + |
| 129 | + current_os = platform.system() # Get the current operating system |
| 130 | + if current_os == "Windows": # If the current operating system is Windows |
| 131 | + return # Do nothing |
| 132 | + |
| 133 | + if verify_filepath_exists(SOUND_FILE): # If the sound file exists |
| 134 | + if current_os in SOUND_COMMANDS: # If the platform.system() is in the SOUND_COMMANDS dictionary |
| 135 | + os.system(f"{SOUND_COMMANDS[current_os]} {SOUND_FILE}") # Play the sound |
| 136 | + else: # If the platform.system() is not in the SOUND_COMMANDS dictionary |
| 137 | + print(f"{BackgroundColors.RED}The {BackgroundColors.CYAN}{current_os}{BackgroundColors.RED} is not in the {BackgroundColors.CYAN}SOUND_COMMANDS dictionary{BackgroundColors.RED}. Please add it!{Style.RESET_ALL}") |
| 138 | + else: # If the sound file does not exist |
| 139 | + print(f"{BackgroundColors.RED}Sound file {BackgroundColors.CYAN}{SOUND_FILE}{BackgroundColors.RED} not found. Make sure the file exists.{Style.RESET_ALL}") |
| 140 | + |
| 141 | +def main(): |
| 142 | + """ |
| 143 | + Main function. |
| 144 | +
|
| 145 | + :param: None |
| 146 | + :return: None |
| 147 | + """ |
| 148 | + |
| 149 | + print(f"{BackgroundColors.CLEAR_TERMINAL}{BackgroundColors.BOLD}{BackgroundColors.GREEN}Welcome to the {BackgroundColors.CYAN}WGAN-GP Data Augmentation{BackgroundColors.GREEN} program!{Style.RESET_ALL}", end="\n\n") # Output the welcome message |
| 150 | + |
| 151 | + args = parse_args() # Parse command-line arguments |
| 152 | + if args.mode == "train": # If training mode is selected |
| 153 | + assert args.csv_path is not None, "Training requires --csv_path" # Ensure CSV path is provided |
| 154 | + train(args) # Run training function |
| 155 | + elif args.mode == "gen": # If generation mode is selected |
| 156 | + assert args.checkpoint is not None, "Generation requires --checkpoint" # Ensure checkpoint is provided |
| 157 | + generate(args) # Run generation function |
| 158 | + |
| 159 | + print(f"\n{BackgroundColors.BOLD}{BackgroundColors.GREEN}Program finished.{Style.RESET_ALL}") # Output the end of the program message |
| 160 | + |
| 161 | + atexit.register(play_sound) if RUN_FUNCTIONS["Play Sound"] else None # Register the play_sound function to be called when the program finishes |
| 162 | + |
| 163 | +if __name__ == "__main__": |
| 164 | + """ |
| 165 | + This is the standard boilerplate that calls the main() function. |
| 166 | +
|
| 167 | + :return: None |
| 168 | + """ |
| 169 | + |
| 170 | + main() # Call the main function |
0 commit comments