  print(f"RMSE for {datasets_names[i]} : {root_mean_squared_error(preds, y_test)}")

+ # Example 1: Adam optimizer (Optax)
+ regr = ns.ElasticNet2Regressor(
+     solver="adam",        # Optax optimizer name
+     learning_rate=0.01,   # Learning rate
+     max_iter=1000,        # Max iterations
+     tol=1e-4,             # Tolerance for early stopping
+     verbose=True          # Print progress
+ )
+ start = time()
+ regr.fit(X_train, y_train)
+ preds = regr.predict(X_test)
+ print(f"Adam - RMSE for {datasets_names[i]}: {root_mean_squared_error(preds, y_test)}")
+ print(f"Elapsed: {time() - start:.2f}s\n")
+
+ # Example 2: SGD optimizer (Optax) with quantile loss
+ regr = ns.ElasticNet2Regressor(
+     solver="sgd",          # Stochastic gradient descent
+     learning_rate=0.001,   # Smaller learning rate for SGD
+     max_iter=1500,
+     type_loss='quantile',  # Quantile (pinball) loss
+     quantile=0.5           # Median regression
+ )
+ start = time()
+ regr.fit(X_train, y_train)
+ preds = regr.predict(X_test)
+ print(f"SGD (Quantile) - RMSE for {datasets_names[i]}: {root_mean_squared_error(preds, y_test)}")
+ print(f"Elapsed: {time() - start:.2f}s\n")
+