Skip to content

Commit 4886721

Browse files
authored
Merge pull request #78 from verdhanyash/feat/validation-split-mlflow-tracking
feat: Validation Split & MLflow Tracking (fixes #22)
2 parents 2c18033 + 8dd5418 commit 4886721

File tree

4 files changed

+320
-6
lines changed

4 files changed

+320
-6
lines changed

etna/api.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def __init__(self, file_path: str, target: str, task_type: str = None, hidden_la
3232
self.target = target
3333
self.df = load_data(file_path)
3434
self.loss_history = []
35+
self.val_loss_history = []
3536

3637
# --- SEED LOGIC ---
3738
self.seed = seed
@@ -67,6 +68,34 @@ def __init__(self, file_path: str, target: str, task_type: str = None, hidden_la
6768
# Cached transformed data for persistence-safe prediction
6869
self._cached_X = None
6970

71+
def _calculate_validation_loss(self, X_val, y_val):
72+
"""
73+
Calculate validation loss using the Rust model's forward pass.
74+
75+
Args:
76+
X_val: Validation features (numpy array).
77+
y_val: Validation targets (numpy array).
78+
79+
Returns:
80+
float: Validation loss value.
81+
"""
82+
preds = self.rust_model.forward(X_val)
83+
84+
if self.task_type == "classification":
85+
# Cross-entropy loss
86+
loss = 0.0
87+
for p_row, y_row in zip(preds, y_val):
88+
for p_val, y_true in zip(p_row, y_row):
89+
loss += -y_true * np.log(p_val + 1e-7)
90+
return loss / len(preds)
91+
else:
92+
# MSE loss
93+
loss = 0.0
94+
for p_row, y_row in zip(preds, y_val):
95+
for p_val, y_true in zip(p_row, y_row):
96+
loss += (p_val - y_true) ** 2
97+
return loss / len(preds)
98+
7099
def train(
71100
self,
72101
epochs: int = 100,
@@ -77,6 +106,7 @@ def train(
77106
early_stopping: bool = False,
78107
patience: int = 10,
79108
restore_best: bool = True,
109+
validation_split: float = 0.2,
80110
):
81111
"""
82112
Train the model.
@@ -90,21 +120,53 @@ def train(
90120
early_stopping: If True, stop training when loss stops improving.
91121
patience: Number of epochs with no improvement before stopping.
92122
restore_best: If True, restore weights from the best epoch.
123+
validation_split: Fraction of data to use for validation (0.0 to 1.0).
124+
Set to 0.0 to disable validation. Default: 0.2.
93125
"""
94126
if _etna_rust is None:
95127
raise ImportError(
96128
"Rust core is not available. Please build the Rust extension "
97129
"before calling model.train()."
98130
)
99131

132+
if not (0.0 <= validation_split < 1.0):
133+
raise ValueError(
134+
f"validation_split must be >= 0.0 and < 1.0, got {validation_split}"
135+
)
136+
100137
print("[*] Preprocessing data...")
101138
X, y = self.preprocessor.fit_transform(self.df, self.target)
102139

103140
# Ensure contiguous float32 arrays for zero-copy transfer to Rust
104141
X = np.ascontiguousarray(X, dtype=np.float32)
105142
y = np.ascontiguousarray(y, dtype=np.float32)
106143

107-
# Cache training data for predict() without arguments
144+
# --- Validation Split ---
145+
X_val = None
146+
y_val = None
147+
if validation_split > 0.0:
148+
n_samples = X.shape[0]
149+
n_val = max(1, int(n_samples * validation_split))
150+
151+
# Shuffle indices before splitting (use seed for reproducibility)
152+
rng = np.random.default_rng(self.seed)
153+
indices = rng.permutation(n_samples)
154+
155+
val_indices = indices[:n_val]
156+
train_indices = indices[n_val:]
157+
158+
X_val = np.ascontiguousarray(X[val_indices], dtype=np.float32)
159+
y_val = np.ascontiguousarray(y[val_indices], dtype=np.float32)
160+
X_train = np.ascontiguousarray(X[train_indices], dtype=np.float32)
161+
y_train = np.ascontiguousarray(y[train_indices], dtype=np.float32)
162+
163+
print(f"[*] Data split: {len(train_indices)} training samples, {len(val_indices)} validation samples")
164+
else:
165+
X_train = X
166+
y_train = y
167+
print("[*] Validation disabled (validation_split=0.0)")
168+
169+
# Cache full data for predict() without arguments
108170
self._cached_X = X
109171

110172
self.input_dim = X.shape[1]
@@ -114,7 +176,6 @@ def train(
114176
if optimizer_lower not in ['sgd', 'adam']:
115177
raise ValueError(f"Unsupported optimizer '{optimizer}'. Choose 'sgd' or 'adam'.")
116178

117-
# LOGICAL FIX: Only initialize if model doesn't exist
118179
# Only initialize if model doesn't exist (supports incremental training)
119180
if self.rust_model is None:
120181
print(f"[*] Initializing Rust Core [In: {self.input_dim}, Out: {self.output_dim}]...")
@@ -138,15 +199,24 @@ def train(
138199
# Create tqdm progress bar
139200
pbar = tqdm(total=epochs, desc="Training", unit="epoch")
140201

202+
# Storage for per-epoch validation losses computed inside callback
203+
epoch_val_losses = []
204+
141205
# Callback function that Rust calls after each epoch
142206
def progress_callback(epoch, total, loss):
143207
pbar.update(1)
144-
pbar.set_description(f"Loss: {loss:.4f}")
208+
# Compute validation loss if validation data is available
209+
if X_val is not None and y_val is not None:
210+
val_loss = self._calculate_validation_loss(X_val, y_val)
211+
epoch_val_losses.append(val_loss)
212+
pbar.set_description(f"Loss: {loss:.4f} | Val Loss: {val_loss:.4f}")
213+
else:
214+
pbar.set_description(f"Loss: {loss:.4f}")
145215

146216
# Single Rust call - training loop stays in Rust for performance
147217
new_losses = self.rust_model.train(
148-
X,
149-
y,
218+
X_train,
219+
y_train,
150220
epochs,
151221
lr,
152222
batch_size,
@@ -160,6 +230,7 @@ def progress_callback(epoch, total, loss):
160230

161231
pbar.close()
162232
self.loss_history.extend(new_losses)
233+
self.val_loss_history.extend(epoch_val_losses)
163234
print("[+] Training complete!")
164235

165236
def predict(self, data_path: str = None):
@@ -267,6 +338,8 @@ def save_model(self, path="model_checkpoint.json", run_name="ETNA_Run", mlflow_t
267338
mlflow.log_param("target_column", self.target)
268339
for epoch, loss in enumerate(self.loss_history):
269340
mlflow.log_metric("loss", loss, step=epoch)
341+
for epoch, val_loss in enumerate(self.val_loss_history):
342+
mlflow.log_metric("val_loss", val_loss, step=epoch)
270343
mlflow.log_artifact(path)
271344
mlflow.log_artifact(preprocessor_path)
272345
print("Model saved & tracked!")
@@ -323,6 +396,7 @@ def load(cls, path: str):
323396
self.file_path = None
324397
self.df = None
325398
self.loss_history = []
399+
self.val_loss_history = []
326400

327401
print("[+] Model loaded successfully!")
328402
return self

etna_core/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ impl EtnaModel {
100100
Ok(history)
101101
}
102102

103+
/// Expose raw forward pass outputs (pre-argmax) for validation loss computation.
104+
fn forward(&mut self, x: PyReadonlyArray2<'_, f32>) -> PyResult<Vec<Vec<f32>>> {
105+
let x_vec = ndarray_to_vec2(x);
106+
Ok(self.inner.forward(&x_vec))
107+
}
108+
103109
fn predict(&mut self, x: PyReadonlyArray2<'_, f32>) -> PyResult<Vec<f32>> {
104110
let x_vec = ndarray_to_vec2(x);
105111
Ok(self.inner.predict(&x_vec))

tests/test_tqdm_progress.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def mock_train(
6565

6666
# Create model and train with progress bar
6767
model = etna.api.Model("dummy.csv", "target", task_type="classification")
68-
model.train(epochs=5, lr=0.01)
68+
model.train(epochs=5, lr=0.01, validation_split=0.0)
6969

7070
# Verify train was called only ONCE (all epochs in Rust)
7171
assert mock_model.train.call_count == 1, f"Expected 1 train call, got {mock_model.train.call_count}"

0 commit comments

Comments
 (0)