Skip to content

Commit 7bd7ec0

Browse files
committed
fix: black formatting
1 parent f6562d6 commit 7bd7ec0

File tree

11 files changed, with 141 additions and 50 deletions.

src/data/dataset.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,14 @@
1313
class FraudDataset(InMemoryDataset):
1414
"""Wraps transaction data into a PyG InMemoryDataset."""
1515

16-
def __init__(self, root, transactions=None, transform=None,
17-
pre_transform=None, pre_filter=None):
16+
def __init__(
17+
self,
18+
root,
19+
transactions=None,
20+
transform=None,
21+
pre_transform=None,
22+
pre_filter=None,
23+
):
1824
self.transactions = transactions
1925
super().__init__(root, transform, pre_transform, pre_filter)
2026
self.load(self.processed_paths[0])
@@ -62,8 +68,9 @@ def __repr__(self):
6268
return f"{self.__class__.__name__}()"
6369

6470

65-
def create_synthetic_fraud_data(num_users=1000, num_merchants=200,
66-
num_transactions=10000, fraud_rate=0.05, seed=42):
71+
def create_synthetic_fraud_data(
72+
num_users=1000, num_merchants=200, num_transactions=10000, fraud_rate=0.05, seed=42
73+
):
6774
"""Generate fake transaction data for testing."""
6875
np.random.seed(seed)
6976

@@ -135,7 +142,9 @@ def load_kaggle_fraud_data(path):
135142
return df[["user_id", "merchant_id", "amount", "timestamp", "is_fraud"]]
136143

137144

138-
def split_temporal(transactions, timestamp_col="timestamp", train_ratio=0.7, val_ratio=0.15):
145+
def split_temporal(
146+
transactions, timestamp_col="timestamp", train_ratio=0.7, val_ratio=0.15
147+
):
139148
"""Split by time so we don't leak future data into training."""
140149
df = transactions.sort_values(timestamp_col).reset_index(drop=True)
141150

src/data/features.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
class FeatureExtractor:
88
"""Extracts behavioral, temporal and network features from transaction data."""
99

10-
def __init__(self, user_col="user_id", merchant_col="merchant_id",
11-
amount_col="amount", timestamp_col="timestamp"):
10+
def __init__(
11+
self,
12+
user_col="user_id",
13+
merchant_col="merchant_id",
14+
amount_col="amount",
15+
timestamp_col="timestamp",
16+
):
1217
self.user_col = user_col
1318
self.merchant_col = merchant_col
1419
self.amount_col = amount_col
@@ -215,8 +220,9 @@ def get_feature_names(self):
215220
]
216221

217222

218-
def compute_velocity_features(transactions, user_col="user_id",
219-
timestamp_col="timestamp", windows=[1, 6, 24]):
223+
def compute_velocity_features(
224+
transactions, user_col="user_id", timestamp_col="timestamp", windows=[1, 6, 24]
225+
):
220226
"""Transaction frequency in rolling time windows."""
221227
df = transactions.copy()
222228
df["timestamp"] = pd.to_datetime(df[timestamp_col])
@@ -236,7 +242,9 @@ def compute_velocity_features(transactions, user_col="user_id",
236242
return df
237243

238244

239-
def compute_graph_features(transactions, user_col="user_id", merchant_col="merchant_id"):
245+
def compute_graph_features(
246+
transactions, user_col="user_id", merchant_col="merchant_id"
247+
):
240248
"""Graph-based features using NetworkX (degree, clustering coeff)."""
241249
import networkx as nx
242250

src/data/graph_builder.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@ class TransactionGraphBuilder:
1010
"""Builds transaction graphs from tabular data. Users and merchants become
1111
nodes, transactions become edges."""
1212

13-
def __init__(self, user_col="user_id", merchant_col="merchant_id",
14-
amount_col="amount", timestamp_col="timestamp", label_col="is_fraud"):
13+
def __init__(
14+
self,
15+
user_col="user_id",
16+
merchant_col="merchant_id",
17+
amount_col="amount",
18+
timestamp_col="timestamp",
19+
label_col="is_fraud",
20+
):
1521
self.user_col = user_col
1622
self.merchant_col = merchant_col
1723
self.amount_col = amount_col
@@ -181,7 +187,9 @@ def _compute_node_features(self, transactions, num_nodes):
181187

182188
return features
183189

184-
def get_train_test_masks(self, num_samples, train_ratio=0.7, val_ratio=0.15, seed=42):
190+
def get_train_test_masks(
191+
self, num_samples, train_ratio=0.7, val_ratio=0.15, seed=42
192+
):
185193
"""Create train/val/test masks."""
186194
np.random.seed(seed)
187195
indices = np.random.permutation(num_samples)
@@ -203,7 +211,9 @@ def get_train_test_masks(self, num_samples, train_ratio=0.7, val_ratio=0.15, see
203211

204212
return train_mask, val_mask, test_mask
205213

206-
def build_hetero_graph(self, transactions, device_col="device_id", include_features=True):
214+
def build_hetero_graph(
215+
self, transactions, device_col="device_id", include_features=True
216+
):
207217
"""Build heterogeneous graph with user/merchant/device node types."""
208218
data = HeteroData()
209219

@@ -374,4 +384,3 @@ def _compute_merchant_features(self, transactions, num_merchants):
374384
features[:, i] = (col - col.mean()) / col.std()
375385

376386
return features
377-

src/inference/explainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _compute_feature_importance(self, edge_idx):
148148

149149
return importance
150150

151-
def _get_prediction(self, node_idx: int) -> Dict:
151+
def _get_prediction(self, node_idx: int) -> dict:
152152
"""Get prediction for a node."""
153153
with torch.no_grad():
154154
out = self.model(
@@ -164,7 +164,7 @@ def _get_prediction(self, node_idx: int) -> Dict:
164164
"confidence": probs.max().item(),
165165
}
166166

167-
def _get_edge_prediction(self, edge_idx: int) -> Dict:
167+
def _get_edge_prediction(self, edge_idx: int) -> dict:
168168
"""Get prediction for an edge."""
169169
with torch.no_grad():
170170
out = self.model(

src/models/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .gat import FraudGAT
2-
from .gin import FraudGIN
3-
from .graphsage import FraudGraphSAGE
4-
from .hetero_gnn import HeteroFraudGNN
1+
from .gat import FraudGAT # noqa: F401
2+
from .gin import FraudGIN # noqa: F401
3+
from .graphsage import FraudGraphSAGE # noqa: F401
4+
from .hetero_gnn import HeteroFraudGNN # noqa: F401

src/models/gat.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,16 @@ class FraudGAT(nn.Module):
1010
"""Multi-head attention GNN. The attention weights are useful
1111
for interpretability — you can see which neighbors matter most."""
1212

13-
def __init__(self, in_channels, hidden_channels, out_channels,
14-
num_layers=3, heads=4, dropout=0.3, attention_dropout=0.3):
13+
def __init__(
14+
self,
15+
in_channels,
16+
hidden_channels,
17+
out_channels,
18+
num_layers=3,
19+
heads=4,
20+
dropout=0.3,
21+
attention_dropout=0.3,
22+
):
1523
super().__init__()
1624

1725
self.num_layers = num_layers
@@ -119,8 +127,15 @@ def get_attention_weights(self, x, edge_index):
119127
class EdgeFraudGAT(nn.Module):
120128
"""Edge-level GAT — classifies transactions as fraud/legit."""
121129

122-
def __init__(self, in_channels, hidden_channels, edge_channels=0,
123-
num_layers=3, heads=4, dropout=0.3):
130+
def __init__(
131+
self,
132+
in_channels,
133+
hidden_channels,
134+
edge_channels=0,
135+
num_layers=3,
136+
heads=4,
137+
dropout=0.3,
138+
):
124139
super().__init__()
125140

126141
self.node_encoder = FraudGAT(

src/models/gin.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,15 @@
1313

1414
class FraudGIN(nn.Module):
1515

16-
def __init__(self, in_channels, hidden_channels, out_channels,
17-
num_layers=3, dropout=0.3, train_eps=True):
16+
def __init__(
17+
self,
18+
in_channels,
19+
hidden_channels,
20+
out_channels,
21+
num_layers=3,
22+
dropout=0.3,
23+
train_eps=True,
24+
):
1825
super().__init__()
1926

2027
self.num_layers = num_layers
@@ -73,8 +80,15 @@ class GINWithJK(nn.Module):
7380
"""GIN + Jumping Knowledge — concatenates representations from all layers.
7481
Helps when graph has varying depths/diameters."""
7582

76-
def __init__(self, in_channels, hidden_channels, out_channels,
77-
num_layers=3, dropout=0.3, jk_mode="cat"):
83+
def __init__(
84+
self,
85+
in_channels,
86+
hidden_channels,
87+
out_channels,
88+
num_layers=3,
89+
dropout=0.3,
90+
jk_mode="cat",
91+
):
7892
super().__init__()
7993

8094
self.num_layers = num_layers
@@ -141,8 +155,9 @@ def forward(self, x, edge_index, edge_attr=None):
141155
class EdgeFraudGIN(nn.Module):
142156
"""Edge-level GIN."""
143157

144-
def __init__(self, in_channels, hidden_channels, edge_channels=0,
145-
num_layers=3, dropout=0.3):
158+
def __init__(
159+
self, in_channels, hidden_channels, edge_channels=0, num_layers=3, dropout=0.3
160+
):
146161
super().__init__()
147162

148163
self.node_encoder = FraudGIN(

src/models/graphsage.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
class FraudGraphSAGE(nn.Module):
1010
"""GraphSAGE for node-level fraud detection."""
1111

12-
def __init__(self, in_channels, hidden_channels, out_channels,
13-
num_layers=3, dropout=0.3, aggregator="mean"):
12+
def __init__(
13+
self,
14+
in_channels,
15+
hidden_channels,
16+
out_channels,
17+
num_layers=3,
18+
dropout=0.3,
19+
aggregator="mean",
20+
):
1421
super().__init__()
1522

1623
self.num_layers = num_layers
@@ -65,8 +72,9 @@ def get_embeddings(self, x, edge_index):
6572
class EdgeFraudGraphSAGE(nn.Module):
6673
"""Edge-level fraud detection — predicts per-transaction."""
6774

68-
def __init__(self, in_channels, hidden_channels, edge_channels=0,
69-
num_layers=3, dropout=0.3):
75+
def __init__(
76+
self, in_channels, hidden_channels, edge_channels=0, num_layers=3, dropout=0.3
77+
):
7078
super().__init__()
7179

7280
self.node_encoder = FraudGraphSAGE(

src/models/hetero_gnn.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,15 @@
1414
class HeteroFraudGNN(nn.Module):
1515
"""Heterogeneous GNN with per-type message passing."""
1616

17-
def __init__(self, node_types, edge_types, hidden_channels, out_channels,
18-
num_layers=3, dropout=0.3):
17+
def __init__(
18+
self,
19+
node_types,
20+
edge_types,
21+
hidden_channels,
22+
out_channels,
23+
num_layers=3,
24+
dropout=0.3,
25+
):
1926
super().__init__()
2027

2128
self.node_types = node_types
@@ -118,9 +125,16 @@ def get_embeddings(self, x_dict, edge_index_dict):
118125
class HeteroEdgeFraudGNN(nn.Module):
119126
"""Edge-level hetero GNN — classifies transaction edges."""
120127

121-
def __init__(self, node_types, edge_types, hidden_channels, edge_channels=0,
122-
num_layers=3, dropout=0.3,
123-
target_edge_type=("user", "transacts", "merchant")):
128+
def __init__(
129+
self,
130+
node_types,
131+
edge_types,
132+
hidden_channels,
133+
edge_channels=0,
134+
num_layers=3,
135+
dropout=0.3,
136+
target_edge_type=("user", "transacts", "merchant"),
137+
):
124138
super().__init__()
125139

126140
self.target_edge_type = target_edge_type

src/training/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .losses import FocalLoss
2-
from .trainer import GNNTrainer
1+
from .losses import FocalLoss # noqa: F401
2+
from .trainer import GNNTrainer # noqa: F401

0 commit comments

Comments
 (0)