Skip to content

Commit b757d9f

Browse files
FEAT: Adding the boilerplate for the WGAN-GP Python code
1 parent ab07958 commit b757d9f

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

wgangp.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
"""
2+
================================================================================
3+
WGAN-GP Conditional Generator for DDoS Flow Features (CICDDoS2019)
4+
================================================================================
5+
Author : Breno Farias da Silva
6+
Created : 2025-11-21
7+
Description :
8+
Implements a DRCGAN-like conditional Wasserstein GAN with Gradient Penalty
9+
for tabular flow features from CICDDoS2019 dataset. This module enables
10+
generating synthetic network flow data conditioned on attack labels.
11+
12+
Key features include:
13+
- CSV loader with automatic scaling and label encoding
14+
- Conditional generator with residual blocks (DRC-style architecture)
15+
- MLP discriminator (critic) for Wasserstein loss
16+
- WGAN-GP training loop with gradient penalty for stability
17+
- Checkpoint saving and synthetic sample generation to CSV
18+
- Support for multi-class conditional generation
19+
20+
Usage:
21+
1. Prepare a CSV file with network flow features and labels.
22+
2. Train the model using the train mode:
23+
$ python wgangp.py --mode train --csv_path data.csv --epochs 60
24+
3. Generate synthetic samples using a trained checkpoint:
25+
$ python wgangp.py --mode gen --checkpoint outputs/generator_epoch60.pt --n_samples 1000
26+
27+
Outputs:
28+
- outputs/generator_epoch*.pt — Saved generator checkpoints with metadata
29+
- outputs/discriminator_epoch*.pt — Saved discriminator checkpoints
30+
- generated.csv — Generated synthetic flow samples (via --mode gen)
31+
32+
TODOs:
33+
- Implement learning rate scheduling for better convergence
34+
- Add support for different activation functions
35+
- Extend feature importance analysis for generated data
36+
- Add data quality metrics (statistical distance, mode coverage)
37+
- Implement multi-GPU training support
38+
39+
Dependencies:
40+
- Python >= 3.9
41+
- torch >= 1.9.0
42+
- numpy
43+
- pandas
44+
- scikit-learn
45+
46+
Assumptions & Notes:
47+
- CSV should contain feature columns and a label column
48+
- Features are automatically scaled using StandardScaler
49+
- Labels are encoded via LabelEncoder (categorical to integer)
50+
- Output features are inverse-transformed to original scale
51+
- CUDA is used if available; use --force_cpu to disable
52+
53+
================================================================================
54+
"""
55+
56+
import argparse # For CLI argument parsing
57+
import atexit # For playing a sound when the program finishes
58+
import numpy as np # Numerical operations
59+
import os # For running a command in the terminal
60+
import pandas as pd # For CSV handling
61+
import platform # For getting the operating system name
62+
import random # For reproducibility
63+
import torch # PyTorch core
64+
import torch.nn as nn # Neural network modules
65+
from colorama import Style # For coloring the terminal
66+
from sklearn.preprocessing import StandardScaler, LabelEncoder # For data preprocessing
67+
from torch import autograd # For gradient penalty
68+
from torch.utils.data import Dataset, DataLoader # Dataset and DataLoader
69+
from typing import Any, List, Optional # For Any type hint
70+
71+
# Terminal Colors:
72+
class BackgroundColors:
    """
    Escape sequences used to color and format the terminal output.

    The attributes are plain strings meant to be embedded in printed messages.
    """

    # Colors
    CYAN = "\033[96m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"

    # Text styles
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"

    # Clears the terminal
    CLEAR_TERMINAL = "\033[H\033[J"
80+
81+
# Execution Constants:
VERBOSE = False  # Verbose messages are only printed when this is True

# Sound Constants:
SOUND_FILE = "./.assets/Sounds/NotificationSound.wav"  # Path to the notification sound file
SOUND_COMMANDS = {  # Command used to play a sound, keyed by platform.system() name
    "Darwin": "afplay",
    "Linux": "aplay",
    "Windows": "start",
}

# RUN_FUNCTIONS:
RUN_FUNCTIONS = {"Play Sound": True}  # "Play Sound": play a sound when the program finishes
92+
93+
# Functions Definitions:
94+
95+
def verbose_output(true_string="", false_string=""):
    """
    Prints one of two messages depending on the VERBOSE constant.

    :param true_string: Message printed when VERBOSE is True and this string is not empty.
    :param false_string: Fallback message printed when the true_string was not printed and this string is not empty.
    :return: None
    """

    if VERBOSE and true_string != "":  # Verbose mode with a verbose message available
        message = true_string  # Select the verbose message
    elif false_string != "":  # Otherwise fall back to the alternative message, if any
        message = false_string  # Select the fallback message
    else:  # Nothing to output
        return  # Leave without printing

    print(message)  # Output the selected message
108+
109+
def verify_filepath_exists(filepath):
    """
    Tell whether a file or folder exists at the given path.

    :param filepath: Path to the file or folder
    :return: True if the path exists, False otherwise
    """

    message = f"{BackgroundColors.GREEN}Verifying if the file or folder exists at the path: {BackgroundColors.CYAN}{filepath}{Style.RESET_ALL}"  # Verbose message
    verbose_output(message)  # Only printed when VERBOSE is enabled

    exists = os.path.exists(filepath)  # Query the filesystem
    return exists  # True if the file or folder exists, False otherwise
120+
121+
def play_sound():
    """
    Plays the notification sound when the program finishes.

    Does nothing on Windows. On any other operating system, the player
    registered for it in SOUND_COMMANDS is run with SOUND_FILE as its last
    argument. A message is printed instead when the sound file is missing,
    when the operating system has no entry in SOUND_COMMANDS, or when the
    player cannot be started.

    :param: None
    :return: None
    """

    import shlex  # Local import: splits the configured player command into arguments
    import subprocess  # Local import: runs the player without going through a shell

    current_os = platform.system()  # Get the current operating system
    if current_os == "Windows":  # If the current operating system is Windows
        return  # Do nothing

    if not verify_filepath_exists(SOUND_FILE):  # If the sound file does not exist
        print(f"{BackgroundColors.RED}Sound file {BackgroundColors.CYAN}{SOUND_FILE}{BackgroundColors.RED} not found. Make sure the file exists.{Style.RESET_ALL}")
        return  # Nothing to play

    if current_os not in SOUND_COMMANDS:  # If the platform.system() is not in the SOUND_COMMANDS dictionary
        print(f"{BackgroundColors.RED}The {BackgroundColors.CYAN}{current_os}{BackgroundColors.RED} is not in the {BackgroundColors.CYAN}SOUND_COMMANDS dictionary{BackgroundColors.RED}. Please add it!{Style.RESET_ALL}")
        return  # No player known for this operating system

    # Pass the file as a separate argument so that paths containing spaces or
    # shell metacharacters are handled correctly (no shell is involved).
    command = [*shlex.split(SOUND_COMMANDS[current_os]), SOUND_FILE]
    try:
        subprocess.run(command, check=False)  # Play the sound; a non-zero exit status is ignored
    except OSError as error:  # The player executable is missing or could not be started
        print(f"{BackgroundColors.RED}Could not play {BackgroundColors.CYAN}{SOUND_FILE}{BackgroundColors.RED}: {error}{Style.RESET_ALL}")
140+
141+
def main():
    """
    Main function.

    Prints the welcome message, parses the command-line arguments and runs
    the training or the generation routine according to the selected mode.
    When enabled in RUN_FUNCTIONS, play_sound is registered to run at
    interpreter exit.

    :param: None
    :return: None
    :raises ValueError: If the train mode is selected without --csv_path or the gen mode without --checkpoint
    """

    print(f"{BackgroundColors.CLEAR_TERMINAL}{BackgroundColors.BOLD}{BackgroundColors.GREEN}Welcome to the {BackgroundColors.CYAN}WGAN-GP Data Augmentation{BackgroundColors.GREEN} program!{Style.RESET_ALL}", end="\n\n")  # Output the welcome message

    # NOTE(review): parse_args, train and generate are not defined in the
    # visible part of this file; confirm they exist before running.
    args = parse_args()  # Parse command-line arguments

    # Explicit checks are used instead of assert statements, because asserts
    # are removed when Python runs with the -O flag.
    if args.mode == "train":  # If training mode is selected
        if args.csv_path is None:  # Ensure CSV path is provided
            raise ValueError("Training requires --csv_path")
        train(args)  # Run training function
    elif args.mode == "gen":  # If generation mode is selected
        if args.checkpoint is None:  # Ensure checkpoint is provided
            raise ValueError("Generation requires --checkpoint")
        generate(args)  # Run generation function

    print(f"\n{BackgroundColors.BOLD}{BackgroundColors.GREEN}Program finished.{Style.RESET_ALL}")  # Output the end of the program message

    if RUN_FUNCTIONS["Play Sound"]:  # If the sound notification is enabled
        atexit.register(play_sound)  # Register the play_sound function to be called when the program finishes
162+
163+
if __name__ == "__main__":
    # Standard boilerplate: main() is only called when this file is executed
    # directly, not when it is imported as a module.
    main()  # Call the main function

0 commit comments

Comments
 (0)