-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_adversarial.py
More file actions
66 lines (47 loc) · 2.09 KB
/
train_adversarial.py
File metadata and controls
66 lines (47 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import logging
import pickle
from pathlib import Path
from src.adversarial.agent import FraudsterAgent
# Module-level logger, named after this module so records show their origin.
logger = logging.getLogger(__name__)

# Root logging configuration: emit INFO and above, with each record rendered
# as "timestamp - logger name - level - message".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
def adversarial_training(
    initial_model_path: str = "models/fraud_model.pkl",
    feature_engineer_path: str = "models/feature_engineer.pkl",
    n_rounds: int = 5,
    n_episodes_per_round: int = 3,
    n_transactions: int = 100,
    output_path: str = "models/adversarial_agent.pkl",
) -> dict[str, list]:
    """Run adversarial training rounds against the current fraud model.

    A fresh ``FraudsterAgent`` is trained for ``n_rounds`` rounds of
    ``n_episodes_per_round`` episodes each. After every round the mean
    evasion rate of that round's episodes and the current size of the
    agent's Q-table are recorded. The agent is saved to ``output_path``
    once all rounds have finished.

    Args:
        initial_model_path: Path of the pickled fraud model to train against.
        feature_engineer_path: Path of the pickled feature engineer that is
            handed to the agent together with the model.
        n_rounds: Number of training rounds; must be at least 1.
        n_episodes_per_round: Episodes per round; must be at least 1 so a
            per-round average can be computed.
        n_transactions: Forwarded as ``n_transactions`` to every
            ``agent.train_episode`` call.
        output_path: Where the agent is saved; missing parent directories
            are created.

    Returns:
        A dict with three parallel lists, one entry per round:
        ``"round"`` (0-based round index), ``"evasion_rate"`` (mean evasion
        rate over the round's episodes) and ``"q_table_size"``
        (``len(agent.q_table)`` at the end of the round).

    Raises:
        ValueError: If ``n_rounds`` or ``n_episodes_per_round`` is below 1.
    """
    # Fail fast, before touching the filesystem: zero episodes would divide by
    # zero when averaging, and zero rounds would save an untrained agent and
    # then crash while reporting the (non-existent) final round.
    if n_rounds < 1:
        raise ValueError(f"n_rounds must be at least 1, got {n_rounds}")
    if n_episodes_per_round < 1:
        raise ValueError(
            f"n_episodes_per_round must be at least 1, got {n_episodes_per_round}"
        )

    logger.info("Loading initial model and feature engineer")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point these paths at artifacts from a trusted source.
    with open(initial_model_path, "rb") as f:
        model = pickle.load(f)
    with open(feature_engineer_path, "rb") as f:
        feature_engineer = pickle.load(f)

    agent = FraudsterAgent()
    history: dict[str, list] = {"round": [], "evasion_rate": [], "q_table_size": []}

    for round_num in range(n_rounds):
        logger.info("Round %s/%s", round_num + 1, n_rounds)
        round_evasion_rates = []

        for episode in range(n_episodes_per_round):
            metrics = agent.train_episode(
                model, feature_engineer, n_transactions=n_transactions
            )
            round_evasion_rates.append(metrics["evasion_rate"])
            logger.info(
                f"Episode {episode + 1}: "
                f"Evasion={metrics['evasion_rate']:.1%}, "
                f"Q-table size={metrics['q_table_size']}"
            )

        # Safe division: n_episodes_per_round >= 1 was validated above.
        avg_evasion = sum(round_evasion_rates) / len(round_evasion_rates)
        history["round"].append(round_num)
        history["evasion_rate"].append(avg_evasion)
        history["q_table_size"].append(len(agent.q_table))
        logger.info(f"Round {round_num + 1} avg evasion rate: {avg_evasion:.1%}")

    logger.info("Saving adversarial agent")
    # Create the destination directory (and any missing parents) before saving.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    agent.save(output_path)

    logger.info("Adversarial training complete")
    logger.info(f"Final evasion rate: {history['evasion_rate'][-1]:.1%}")
    logger.info(f"Q-table size: {history['q_table_size'][-1]}")
    return history
# Script entry point: when the file is executed directly (not imported), run
# adversarial training with its default arguments and keep the per-round
# history in a module-level name.
if __name__ == "__main__":
    history = adversarial_training()