Skip to content

Commit aa59d98

Browse files
use only JAX for ElasticNet2Regressor Pt.2
1 parent 8fcd509 commit aa59d98

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

examples/elasticnet2regressor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,31 @@
3434

3535
print(f"RMSE for {datasets_names[i]} : {root_mean_squared_error(preds, y_test)}")
3636

37+
# Example 1: Adam optimizer (Optax)
38+
regr = ns.ElasticNet2Regressor(
39+
solver="adam", # Optax optimizer name
40+
learning_rate=0.01, # Learning rate
41+
max_iter=1000, # Max iterations
42+
tol=1e-4, # Tolerance for early stopping
43+
verbose=True # Print progress
44+
)
45+
start = time()
46+
regr.fit(X_train, y_train)
47+
preds = regr.predict(X_test)
48+
print(f"Adam - RMSE for {datasets_names[i]}: {root_mean_squared_error(preds, y_test)}")
49+
print(f"Elapsed: {time() - start:.2f}s\n")
50+
51+
# Example 2: SGD with momentum (Optax)
52+
regr = ns.ElasticNet2Regressor(
53+
solver="sgd", # Stochastic Gradient Descent
54+
learning_rate=0.001, # Smaller learning rate for SGD
55+
max_iter=1500,
56+
type_loss='quantile', # Quantile regression
57+
quantile=0.5 # Median regression
58+
)
59+
start = time()
60+
regr.fit(X_train, y_train)
61+
preds = regr.predict(X_test)
62+
print(f"SGD (Quantile) - RMSE for {datasets_names[i]}: {root_mean_squared_error(preds, y_test)}")
63+
print(f"Elapsed: {time() - start:.2f}s\n")
64+

0 commit comments

Comments
 (0)