@@ -10,8 +10,14 @@ class TransactionGraphBuilder:
1010 """Builds transaction graphs from tabular data. Users and merchants become
1111 nodes, transactions become edges."""
1212
13- def __init__ (self , user_col = "user_id" , merchant_col = "merchant_id" ,
14- amount_col = "amount" , timestamp_col = "timestamp" , label_col = "is_fraud" ):
13+ def __init__ (
14+ self ,
15+ user_col = "user_id" ,
16+ merchant_col = "merchant_id" ,
17+ amount_col = "amount" ,
18+ timestamp_col = "timestamp" ,
19+ label_col = "is_fraud" ,
20+ ):
1521 self .user_col = user_col
1622 self .merchant_col = merchant_col
1723 self .amount_col = amount_col
@@ -181,7 +187,9 @@ def _compute_node_features(self, transactions, num_nodes):
181187
182188 return features
183189
184- def get_train_test_masks (self , num_samples , train_ratio = 0.7 , val_ratio = 0.15 , seed = 42 ):
190+ def get_train_test_masks (
191+ self , num_samples , train_ratio = 0.7 , val_ratio = 0.15 , seed = 42
192+ ):
185193 """Create train/val/test masks."""
186194 np .random .seed (seed )
187195 indices = np .random .permutation (num_samples )
@@ -203,7 +211,9 @@ def get_train_test_masks(self, num_samples, train_ratio=0.7, val_ratio=0.15, see
203211
204212 return train_mask , val_mask , test_mask
205213
206- def build_hetero_graph (self , transactions , device_col = "device_id" , include_features = True ):
214+ def build_hetero_graph (
215+ self , transactions , device_col = "device_id" , include_features = True
216+ ):
207217 """Build heterogeneous graph with user/merchant/device node types."""
208218 data = HeteroData ()
209219
@@ -374,4 +384,3 @@ def _compute_merchant_features(self, transactions, num_merchants):
374384 features [:, i ] = (col - col .mean ()) / col .std ()
375385
376386 return features
377-
0 commit comments