@@ -32,6 +32,7 @@ def __init__(self, file_path: str, target: str, task_type: str = None, hidden_la
3232 self .target = target
3333 self .df = load_data (file_path )
3434 self .loss_history = []
35+ self .val_loss_history = []
3536
3637 # --- SEED LOGIC ---
3738 self .seed = seed
@@ -67,6 +68,34 @@ def __init__(self, file_path: str, target: str, task_type: str = None, hidden_la
6768 # Cached transformed data for persistence-safe prediction
6869 self ._cached_X = None
6970
def _calculate_validation_loss(self, X_val, y_val):
    """
    Calculate validation loss using the Rust model's forward pass.

    The loss is the element-wise loss summed over all outputs and then
    averaged over the number of validation samples:

    * ``task_type == "classification"``: cross-entropy,
      ``-sum(y * log(p + 1e-7)) / n_samples``.
    * any other task type: squared error,
      ``sum((p - y) ** 2) / n_samples``.

    Args:
        X_val: Validation features (numpy array), passed unchanged to
            ``self.rust_model.forward``.
        y_val: Validation targets (numpy array). Must contain exactly as
            many elements as the predictions; a 1-D target vector is
            accepted for single-output models.

    Returns:
        float: Validation loss value.

    Raises:
        ValueError: If the model returns no predictions, or if the number
            of target values does not match the number of predicted values.
    """
    # Work in float64 so the accumulated sum does not lose precision when
    # the inputs are float32 arrays or nested Python lists.
    preds = np.atleast_1d(
        np.asarray(self.rust_model.forward(X_val), dtype=np.float64)
    )
    targets = np.asarray(y_val, dtype=np.float64)

    if preds.size == 0:
        raise ValueError(
            "Cannot compute validation loss: the model returned no predictions."
        )

    # One row per sample; a flat prediction vector becomes a single column.
    n_samples = preds.shape[0]
    preds = preds.reshape(n_samples, -1)

    # Pairing rows/columns of mismatched arrays would silently drop values
    # and yield a meaningless loss, so fail loudly instead.
    if targets.size != preds.size:
        raise ValueError(
            f"Shape mismatch between predictions {preds.shape} "
            f"and targets {targets.shape}."
        )
    targets = targets.reshape(preds.shape)

    if self.task_type == "classification":
        # Cross-entropy loss; the small epsilon keeps log() finite at p == 0.
        epsilon = 1e-7
        total = -np.sum(targets * np.log(preds + epsilon))
    else:
        # Squared-error loss (summed over outputs, averaged over samples).
        total = np.sum((preds - targets) ** 2)

    return float(total / n_samples)
98+
7099 def train (
71100 self ,
72101 epochs : int = 100 ,
@@ -77,6 +106,7 @@ def train(
77106 early_stopping : bool = False ,
78107 patience : int = 10 ,
79108 restore_best : bool = True ,
109+ validation_split : float = 0.2 ,
80110 ):
81111 """
82112 Train the model.
@@ -90,21 +120,53 @@ def train(
90120 early_stopping: If True, stop training when loss stops improving.
91121 patience: Number of epochs with no improvement before stopping.
92122 restore_best: If True, restore weights from the best epoch.
123+ validation_split: Fraction of data to use for validation (0.0 to 1.0).
124+ Set to 0.0 to disable validation. Default: 0.2.
93125 """
94126 if _etna_rust is None :
95127 raise ImportError (
96128 "Rust core is not available. Please build the Rust extension "
97129 "before calling model.train()."
98130 )
99131
132+ if not (0.0 <= validation_split < 1.0 ):
133+ raise ValueError (
134+ f"validation_split must be >= 0.0 and < 1.0, got { validation_split } "
135+ )
136+
100137 print ("[*] Preprocessing data..." )
101138 X , y = self .preprocessor .fit_transform (self .df , self .target )
102139
103140 # Ensure contiguous float32 arrays for zero-copy transfer to Rust
104141 X = np .ascontiguousarray (X , dtype = np .float32 )
105142 y = np .ascontiguousarray (y , dtype = np .float32 )
106143
107- # Cache training data for predict() without arguments
144+ # --- Validation Split ---
145+ X_val = None
146+ y_val = None
147+ if validation_split > 0.0 :
148+ n_samples = X .shape [0 ]
149+ n_val = max (1 , int (n_samples * validation_split ))
150+
151+ # Shuffle indices before splitting (use seed for reproducibility)
152+ rng = np .random .default_rng (self .seed )
153+ indices = rng .permutation (n_samples )
154+
155+ val_indices = indices [:n_val ]
156+ train_indices = indices [n_val :]
157+
158+ X_val = np .ascontiguousarray (X [val_indices ], dtype = np .float32 )
159+ y_val = np .ascontiguousarray (y [val_indices ], dtype = np .float32 )
160+ X_train = np .ascontiguousarray (X [train_indices ], dtype = np .float32 )
161+ y_train = np .ascontiguousarray (y [train_indices ], dtype = np .float32 )
162+
163+ print (f"[*] Data split: { len (train_indices )} training samples, { len (val_indices )} validation samples" )
164+ else :
165+ X_train = X
166+ y_train = y
167+ print ("[*] Validation disabled (validation_split=0.0)" )
168+
169+ # Cache full data for predict() without arguments
108170 self ._cached_X = X
109171
110172 self .input_dim = X .shape [1 ]
@@ -114,7 +176,6 @@ def train(
114176 if optimizer_lower not in ['sgd' , 'adam' ]:
115177 raise ValueError (f"Unsupported optimizer '{ optimizer } '. Choose 'sgd' or 'adam'." )
116178
117- # LOGICAL FIX: Only initialize if model doesn't exist
118179 # Only initialize if model doesn't exist (supports incremental training)
119180 if self .rust_model is None :
120181 print (f"[*] Initializing Rust Core [In: { self .input_dim } , Out: { self .output_dim } ]..." )
@@ -138,15 +199,24 @@ def train(
138199 # Create tqdm progress bar
139200 pbar = tqdm (total = epochs , desc = "Training" , unit = "epoch" )
140201
202+ # Storage for per-epoch validation losses computed inside callback
203+ epoch_val_losses = []
204+
141205 # Callback function that Rust calls after each epoch
142206 def progress_callback (epoch , total , loss ):
143207 pbar .update (1 )
144- pbar .set_description (f"Loss: { loss :.4f} " )
208+ # Compute validation loss if validation data is available
209+ if X_val is not None and y_val is not None :
210+ val_loss = self ._calculate_validation_loss (X_val , y_val )
211+ epoch_val_losses .append (val_loss )
212+ pbar .set_description (f"Loss: { loss :.4f} | Val Loss: { val_loss :.4f} " )
213+ else :
214+ pbar .set_description (f"Loss: { loss :.4f} " )
145215
146216 # Single Rust call - training loop stays in Rust for performance
147217 new_losses = self .rust_model .train (
148- X ,
149- y ,
218+ X_train ,
219+ y_train ,
150220 epochs ,
151221 lr ,
152222 batch_size ,
@@ -160,6 +230,7 @@ def progress_callback(epoch, total, loss):
160230
161231 pbar .close ()
162232 self .loss_history .extend (new_losses )
233+ self .val_loss_history .extend (epoch_val_losses )
163234 print ("[+] Training complete!" )
164235
165236 def predict (self , data_path : str = None ):
@@ -267,6 +338,8 @@ def save_model(self, path="model_checkpoint.json", run_name="ETNA_Run", mlflow_t
267338 mlflow .log_param ("target_column" , self .target )
268339 for epoch , loss in enumerate (self .loss_history ):
269340 mlflow .log_metric ("loss" , loss , step = epoch )
341+ for epoch , val_loss in enumerate (self .val_loss_history ):
342+ mlflow .log_metric ("val_loss" , val_loss , step = epoch )
270343 mlflow .log_artifact (path )
271344 mlflow .log_artifact (preprocessor_path )
272345 print ("Model saved & tracked!" )
@@ -323,6 +396,7 @@ def load(cls, path: str):
323396 self .file_path = None
324397 self .df = None
325398 self .loss_history = []
399+ self .val_loss_history = []
326400
327401 print ("[+] Model loaded successfully!" )
328402 return self
0 commit comments