
Commit 92ae025

add attention mechanisms
Parent: 58bee85

File tree

9 files changed: +666, -10 lines changed


examples/attention.py

Lines changed: 103 additions & 0 deletions
import os
import numpy as np
import jax.numpy as jnp
from nnetsauce.attention import AttentionMechanism

print(f"\n ----- Running: {os.path.basename(__file__)}... ----- \n")

# Set random seed for reproducibility
np.random.seed(42)

# Example 1: Univariate time series with temporal attention
print("=" * 50)
print("Example 1: Univariate Time Series")
print("=" * 50)
batch_size, seq_len, input_dim = 32, 10, 1
x_univariate = jnp.array(np.random.randn(batch_size, seq_len, input_dim))

attention = AttentionMechanism(input_dim=input_dim, hidden_dim=32, num_heads=4)
context, weights = attention(x_univariate, attention_type='temporal')

print(f"Input shape: {x_univariate.shape}")
print(f"Context shape: {context.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Sample attention weights (first batch): {np.array(weights[0])}")

# Example 2: Tabular data with feature attention
print("\n" + "=" * 50)
print("Example 2: Tabular Data with Feature Attention")
print("=" * 50)
batch_size, num_features = 32, 10
x_tabular = jnp.array(np.random.randn(batch_size, num_features))

attention_tab = AttentionMechanism(input_dim=num_features, hidden_dim=32)
output, feature_weights = attention_tab(x_tabular, attention_type='feature')

print(f"Input shape: {x_tabular.shape}")
print(f"Output shape: {output.shape}")
print(f"Feature weights shape: {feature_weights.shape}")
print(f"Feature importance (first batch): {np.array(feature_weights[0])}")

# Example 3: Multi-head attention on sequences
print("\n" + "=" * 50)
print("Example 3: Multi-Head Attention")
print("=" * 50)
batch_size, seq_len, input_dim = 16, 8, 16
x_seq = jnp.array(np.random.randn(batch_size, seq_len, input_dim))

attention_mha = AttentionMechanism(input_dim=input_dim, hidden_dim=64, num_heads=8)
output_mha, weights_mha = attention_mha(x_seq, attention_type='multi_head')

print(f"Input shape: {x_seq.shape}")
print(f"Output shape: {output_mha.shape}")
print(f"Attention weights shape (with heads): {weights_mha.shape}")

# Example 4: Cross-attention between a query sequence and a key-value sequence
print("\n" + "=" * 50)
print("Example 4: Cross-Attention")
print("=" * 50)
batch_size = 16
query_seq = jnp.array(np.random.randn(batch_size, 5, input_dim))
kv_seq = jnp.array(np.random.randn(batch_size, 10, input_dim))

cross_output, cross_weights = attention_mha(
    None,
    attention_type='cross',
    query=query_seq,
    key_value=kv_seq
)

print(f"Query shape: {query_seq.shape}")
print(f"Key-Value shape: {kv_seq.shape}")
print(f"Cross-attention output shape: {cross_output.shape}")
print(f"Cross-attention weights shape: {cross_weights.shape}")

# Example 5: Context vector attention (fixed-size global representation)
print("\n" + "=" * 50)
print("Example 5: Context Vector Attention")
print("=" * 50)
batch_size, seq_len, input_dim = 32, 15, 8
x_context = jnp.array(np.random.randn(batch_size, seq_len, input_dim))

attention_ctx = AttentionMechanism(input_dim=input_dim, hidden_dim=64)
context_output, context_weights = attention_ctx(x_context, attention_type='context_vector')

print(f"Input shape: {x_context.shape}")
print(f"Context output shape: {context_output.shape}")
print(f"Context attention weights shape: {context_weights.shape}")
print(f"Sample context weights (first batch): {np.array(context_weights[0])}")
print("\nNote: Context vector attention produces a fixed-size global representation")
print("regardless of input sequence length, making it ideal for classification tasks.")

# Demonstrate JAX's JIT compilation benefit
print("\n" + "=" * 50)
print("JAX Performance Benefits")
print("=" * 50)
print("All methods are JIT-compiled for fast execution!")
print("JAX provides automatic differentiation and GPU/TPU acceleration.")
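A minimal sketch (not part of this commit) of the classification use case mentioned above: the fixed-size context vector from 'context_vector' attention is converted to a NumPy array and used as features for a scikit-learn classifier. The labels below are synthetic and only illustrative; flattening the context output to 2-D is an assumption based on the shapes printed above.

# Hypothetical follow-up, reusing `context_output` and `batch_size` from Example 5.
from sklearn.linear_model import LogisticRegression

y_labels = np.random.randint(0, 2, size=batch_size)           # synthetic binary labels
features = np.array(context_output).reshape(batch_size, -1)   # flatten to (batch, features)
clf = LogisticRegression().fit(features, y_labels)
print(f"In-sample accuracy on synthetic labels: {clf.score(features, y_labels):.2f}")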

examples/conformal_simulation2.py

Lines changed: 96 additions & 0 deletions
import nnetsauce as ns
import matplotlib.pyplot as plt
import numpy as np
import warnings
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import RidgeCV
from time import time


# 2 - Useful plotting functions

warnings.filterwarnings('ignore')

split_color = 'green'
split_color2 = 'tomato'
local_color = 'gray'

def plot_func(x,
              y,
              y_u=None,
              y_l=None,
              pred=None,
              shade_color="",
              method_name="",
              title=""):
    """Plot test observations with an optional prediction interval and point predictions."""
    fig = plt.figure()

    plt.plot(x, y, 'k.', alpha=.3, markersize=10,
             fillstyle='full', label='Test set observations')

    if (y_u is not None) and (y_l is not None):
        plt.fill(np.concatenate([x, x[::-1]]),
                 np.concatenate([y_u, y_l[::-1]]),
                 alpha=.3, fc=shade_color, ec='None',
                 label=method_name + ' Prediction interval')

    if pred is not None:
        plt.plot(x, pred, 'k--', lw=2, alpha=0.9,
                 label='Predicted value')

    plt.xlabel('$X$')
    plt.ylabel('$Y$')
    plt.legend(loc='upper right')
    plt.title(title)

    plt.show()


# 3 - Examples of use

data = fetch_california_housing()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2, random_state=123)


# RidgeCV
"""
Available simulation methods for the prediction intervals (type_pi):
- 'bootstrap': Bootstrap resampling.
- 'kde': Kernel Density Estimation.
- 'ecdf': Empirical CDF-based sampling.
- 'permutation': Permutation resampling.
- 'smooth_bootstrap': Smoothed bootstrap with added noise.
"""

for type_pi in ('bootstrap', 'kde', 'ecdf', 'permutation', 'smooth_bootstrap'):
    print(f"\n\n### type_pi = {type_pi} ###\n")
    regr1 = ns.PredictionInterval(RidgeCV(),
                                  replications=100,
                                  type_pi=type_pi)
    regr1.fit(X_train, y_train)
    start = time()
    preds1 = regr1.predict(X_test, return_pi=True)
    print(f"Elapsed: {time() - start}s")
    print(f"coverage rate, conformalized RidgeCV: {np.mean((preds1.lower <= y_test) * (preds1.upper >= y_test))}")
    print(f"predictive simulations: {preds1[1]}")

    max_idx = 50
    plot_func(x=range(max_idx),
              y=y_test[0:max_idx],
              y_u=preds1.upper[0:max_idx],
              y_l=preds1.lower[0:max_idx],
              pred=preds1.mean[0:max_idx],
              shade_color=split_color2,
              title=f"conformalized RidgeCV ({max_idx} first points in test set)")
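As a variation (not in the commit), the PredictionInterval wrapper is used above with a scikit-learn regressor as base learner, so swapping in a different one should follow the same pattern; a hedged sketch with ExtraTreesRegressor:

# Hypothetical variation on the loop above; assumes the same fit/predict API.
from sklearn.ensemble import ExtraTreesRegressor

regr2 = ns.PredictionInterval(ExtraTreesRegressor(n_estimators=100, random_state=123),
                              replications=100,
                              type_pi='kde')
regr2.fit(X_train, y_train)
preds2 = regr2.predict(X_test, return_pi=True)
print(f"coverage rate, conformalized ExtraTrees: {np.mean((preds2.lower <= y_test) * (preds2.upper >= y_test))}")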

nnetsauce/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,4 @@
+from .attention import AttentionMechanism
 from .base.base import Base
 from .base.baseRegressor import BaseRegressor
 from .boosting.adaBoostClassifier import AdaBoostClassifier
@@ -42,6 +43,7 @@
 
 __all__ = [
     "AdaBoostClassifier",
+    "AttentionMechanism",
     "Base",
     "BaseRegressor",
     "BayesianRVFLRegressor",

nnetsauce/attention/__init__.py

Lines changed: 3 additions & 0 deletions
from .attention import AttentionMechanism

__all__ = ["AttentionMechanism"]
