Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 77 additions & 17 deletions etna/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self, file_path: str, target: str, task_type: str = None, hidden_la
file_path: Path to the .csv dataset
target: Name of the target column
task_type: 'classification', 'regression', or None (auto-detect)
hidden_layers: List of hidden layer sizes (default: [16])
activation: Activation function ('relu', 'leaky_relu', 'sigmoid', default: 'relu')
"""
self.file_path = file_path
self.target = target
Expand Down Expand Up @@ -58,31 +60,62 @@ def __init__(self, file_path: str, target: str, task_type: str = None, hidden_la

# Cached transformed data for persistence-safe prediction
self._cached_X = None
self.val_loss_history = []

def train(self, epochs: int = 100, lr: float = 0.01, batch_size: int = 32, weight_decay: float = 0.0, optimizer: str = 'sgd'):
def train(self, epochs: int = 100, lr: float = 0.01, batch_size: int = 32, weight_decay: float = 0.0, optimizer: str = 'sgd', validation_split: float = 0.2):
"""
Train the model.

Args:
epochs: Number of training epochs
lr: Learning rate
batch_size: Number of samples per gradient update (default: 32)
weight_decay: L2 regularization coefficient (lambda)
optimizer: Optimizer to use ('sgd' or 'adam')
batch_size: Number of samples per gradient update (default: 32, 0 means full batch)
weight_decay: L2 regularization coefficient (lambda). Higher values
lead to smaller weights and help prevent overfitting.
Typical values: 0.0001 to 0.01
optimizer: Optimizer to use ('sgd' or 'adam'). Default is 'sgd'.
Adam optimizer provides better convergence with adaptive learning rates.
validation_split: Fraction of data to use for validation (default: 0.2).
Set to 0.0 to disable validation split.
"""
if _etna_rust is None:
raise ImportError(
"Rust core is not available. Please build the Rust extension "
"before calling model.train()."
)

if not 0.0 <= validation_split < 1.0:
raise ValueError(f"validation_split must be in [0.0, 1.0), got {validation_split}")

if batch_size < 0:
raise ValueError(f"batch_size must be >= 0, got {batch_size}")

print("⚙️ Preprocessing data...")
X, y = self.preprocessor.fit_transform(self.df, self.target)


# Split data into training and validation sets
n_samples = len(X)
if validation_split > 0.0:
split_idx = int(n_samples * (1 - validation_split))
X_train = X[:split_idx]
y_train = y[:split_idx]
X_val = X[split_idx:]
y_val = y[split_idx:]
print(f"📊 Data split: {len(X_train)} training samples, {len(X_val)} validation samples")
else:
X_train = X
y_train = y
X_val = None
y_val = None
print("📊 Using full dataset for training (no validation split)")

# Cache training data for predict() without arguments
self._cached_X = np.array(X)

self.input_dim = len(X[0])
# Use first hidden layer size for current Rust implementation
# TODO: Update when Rust supports multiple hidden layers
self.hidden_dim = self.hidden_layers[0] if self.hidden_layers else 16
self.output_dim = self.preprocessor.output_dim

optimizer_lower = optimizer.lower()
Expand All @@ -94,7 +127,7 @@ def train(self, epochs: int = 100, lr: float = 0.01, batch_size: int = 32, weigh
print(f"🚀 Initializing Rust Core [In: {self.input_dim}, Out: {self.output_dim}]...")
self.rust_model = _etna_rust.EtnaModel(
self.input_dim,
self.hidden_layers,
self.hidden_dim,
self.output_dim,
self.task_code,
self.activation
Expand All @@ -108,21 +141,33 @@ def train(self, epochs: int = 100, lr: float = 0.01, batch_size: int = 32, weigh
else:
print(f"🔥 Training started (Optimizer: {optimizer_display})...")

# Pass optimizer string to Rust backend (it will default to SGD if None or invalid)
new_losses = self.rust_model.train(X, y, epochs, lr, batch_size, weight_decay, optimizer_lower)
# Train in Rust with validation data passed in (optimizer state persists across epochs)
train_losses, val_losses = self.rust_model.train(
X_train,
y_train,
epochs,
lr,
weight_decay,
optimizer_lower,
batch_size,
X_val if validation_split > 0.0 and X_val is not None else None,
y_val if validation_split > 0.0 and y_val is not None else None
)

# Extend history instead of overwriting it
self.loss_history.extend(new_losses)
self.loss_history.extend(train_losses)
if validation_split > 0.0 and val_losses:
self.val_loss_history.extend(val_losses)
print("✅ Training complete!")

def predict(self, data_path: str = None):
"""
Make predictions.

Args:
data_path: Optional path to CSV file. If not provided, uses the
data_path: Optional path to CSV file. If not provided, uses the
training data (useful for evaluating on training set).

Returns:
List of predictions (class labels for classification, values for regression)
"""
Expand Down Expand Up @@ -156,7 +201,7 @@ def predict(self, data_path: str = None):
for p in preds
]
return [float(r) for r in results]

def summary(self):
print("\n Model Summary")
print("=" * 60)
Expand All @@ -166,7 +211,6 @@ def summary(self):
print("Call model.train() before calling summary().")
return


l1_params = (self.input_dim * self.hidden_dim) + self.hidden_dim
print(
f"Layer 1 (Linear): {self.input_dim} -> {self.hidden_dim} "
Expand All @@ -177,7 +221,7 @@ def summary(self):
print(
f"Layer 2 (Linear): {self.hidden_dim} -> {self.output_dim} "
f"| Params: {l2_params}"
)
)

print("=" * 60)
total_params = l1_params + l2_params
Expand All @@ -200,7 +244,9 @@ def save_model(self, path="model_checkpoint.json", run_name="ETNA_Run", mlflow_t

preprocessor_path = path + ".preprocessor.json"
state = self.preprocessor.get_state()
state["_cached_X"] = self._cached_X.tolist() if self._cached_X is not None else None
state["_cached_X"] = (
self._cached_X.tolist() if self._cached_X is not None else None
)
state["_target"] = self.target

with open(preprocessor_path, "w") as f:
Expand All @@ -217,15 +263,27 @@ def save_model(self, path="model_checkpoint.json", run_name="ETNA_Run", mlflow_t
with mlflow.start_run(run_name=run_name):
mlflow.log_param("task_type", self.task_type)
mlflow.log_param("target_column", self.target)

print(f"📈 Logging {len(self.loss_history)} training metrics points...")
for epoch, loss in enumerate(self.loss_history):
mlflow.log_metric("loss", loss, step=epoch)

# Log validation loss if available
if self.val_loss_history:
print(f"📈 Logging {len(self.val_loss_history)} validation metrics points...")
for epoch, val_loss in enumerate(self.val_loss_history):
mlflow.log_metric("val_loss", val_loss, step=epoch)

mlflow.log_artifact(path)
mlflow.log_artifact(preprocessor_path)

print("Model saved & tracked!")
except ImportError:
print("MLflow not installed. Skipping remote tracking.")
elif os.environ.get("ETNA_DISABLE_MLFLOW") == "1":
print(f"Model saved locally to {path}. (MLflow tracking disabled)")
else:
print(f"Model saved locally to {path}. (MLflow tracking skipped)")
print(f"Model saved locally to {path}. (MLflow tracking skipped - no URI provided)")

@classmethod
def load(cls, path: str):
Expand Down Expand Up @@ -275,6 +333,8 @@ def load(cls, path: str):
self.file_path = None
self.df = None
self.loss_history = []
self.val_loss_history = []

print("✅ Model loaded successfully!")
return self

65 changes: 46 additions & 19 deletions etna_core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
// Rust-Python bridge (pyo3)

#![allow(dead_code)]

mod model;
mod layers;
mod loss_function;
mod optimizer;


use pyo3::prelude::*;
use pyo3::types::PyList;

use crate::model::{SimpleNN, OptimizerType};
use crate::layers::Activation;
use crate::model::SimpleNN;
use crate::model::OptimizerType;

/// Safe conversion helper
fn pylist_to_vec2(pylist: &Bound<'_, PyList>) -> PyResult<Vec<Vec<f32>>> {
pylist.iter()
.map(|item| item.extract::<Vec<f32>>())
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e)
.map_err(|e| PyErr::from(e))
}

/// Python Class Wrapper
Expand All @@ -37,6 +39,7 @@ impl EtnaModel {
task_type: usize,
activation: Option<String>,
) -> Self {
// Parse activation string (default: ReLU)
let act = match activation.as_deref().unwrap_or("relu") {
"leaky_relu" => Activation::LeakyReLU,
"sigmoid" => Activation::Sigmoid,
Expand All @@ -48,7 +51,7 @@ impl EtnaModel {
}
}

#[pyo3(signature = (x, y, epochs, lr, batch_size=32, weight_decay=0.0, optimizer="sgd"))]
#[pyo3(signature = (x, y, epochs, lr, batch_size=32, weight_decay=0.0, optimizer="sgd", x_val=None, y_val=None))]
fn train(
&mut self,
x: &Bound<'_, PyList>,
Expand All @@ -58,34 +61,57 @@ impl EtnaModel {
batch_size: usize,
weight_decay: f32,
optimizer: &str,
) -> PyResult<Vec<f32>> {
x_val: Option<&Bound<'_, PyList>>,
y_val: Option<&Bound<'_, PyList>>,
) -> PyResult<(Vec<f32>, Vec<f32>)> {
let x_vec = pylist_to_vec2(x)?;
let y_vec = pylist_to_vec2(y)?;

// Parse optimizer string (default to SGD if not specified or invalid)
let optimizer_type = match optimizer {
"adam" => OptimizerType::Adam,
_ => OptimizerType::SGD,
_ => OptimizerType::SGD, // Default to SGD for backward compatibility
};

// Convert optional validation data
let x_val_opt = match x_val {
Some(v) => Some(pylist_to_vec2(v)?),
None => None,
};
let y_val_opt = match y_val {
Some(v) => Some(pylist_to_vec2(v)?),
None => None,
};

// Capture the history returned by Rust
let history = self.inner.train(
&x_vec,
&y_vec,
epochs,
lr,
weight_decay,
optimizer_type,
batch_size,
);

// Return it to Python
Ok(history)
// Capture the history returned by Rust (both train and val losses)
let (train_history, val_history) = self.inner.train(
&x_vec,
&y_vec,
epochs,
lr,
weight_decay,
optimizer_type,
batch_size,
x_val_opt.as_ref(),
y_val_opt.as_ref()
);

// Return both histories to Python
Ok((train_history, val_history))
}

fn predict(&mut self, x: &Bound<'_, PyList>) -> PyResult<Vec<f32>> {
let x_vec = pylist_to_vec2(x)?;
Ok(self.inner.predict(&x_vec))
}

/// Get raw forward pass outputs (probabilities for classification, values for regression)
/// This is useful for calculating validation loss
fn forward(&mut self, x: &Bound<'_, PyList>) -> PyResult<Vec<Vec<f32>>> {
let x_vec = pylist_to_vec2(x)?;
Ok(self.inner.forward(&x_vec))
}

fn save(&self, path: String) -> PyResult<()> {
self.inner.save(&path).map_err(|e| {
pyo3::exceptions::PyIOError::new_err(format!("Failed to save model: {}", e))
Expand All @@ -107,3 +133,4 @@ fn _etna_rust(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<EtnaModel>()?;
Ok(())
}

Loading