import os
import nnetsauce as ns
import numpy as np
import jax.numpy as jnp
from nnetsauce.attention import AttentionMechanism
from sklearn.datasets import load_diabetes, fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge
from sklearn.ensemble import ExtraTreesRegressor, RandomForestRegressor
from time import time

print(f"\n----- Running: {os.path.basename(__file__)} ... -----\n")

# Set random seed for reproducibility
np.random.seed(42)

# Example 1: Univariate time series with temporal attention
print("=" * 50)
print("Example 1: Univariate Time Series")
print("=" * 50)
batch_size, seq_len, input_dim = 32, 10, 1
x_univariate = jnp.array(np.random.randn(batch_size, seq_len, input_dim))

attention = AttentionMechanism(input_dim=input_dim, hidden_dim=32, num_heads=4)
context, weights = attention(x_univariate, attention_type='temporal')

print(f"Input shape: {x_univariate.shape}")
print(f"Context shape: {context.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Sample attention weights (first batch): {np.array(weights[0])}")
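# Hedged sanity check (not part of the library API): temporal attention weights
# are commonly softmax-normalized over the sequence axis -- that normalization
# and axis choice are assumptions here, so the sums are printed rather than asserted.
weights_np = np.asarray(weights)
print(f"Weight sums over the last axis (first batch): {weights_np[0].sum(axis=-1)}")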

# Example 2: Tabular data with feature attention
print("\n" + "=" * 50)
print("Example 2: Tabular Data with Feature Attention")
print("=" * 50)
batch_size, num_features = 32, 10
x_tabular = jnp.array(np.random.randn(batch_size, num_features))

attention_tab = AttentionMechanism(input_dim=num_features, hidden_dim=32)
output, feature_weights = attention_tab(x_tabular, attention_type='feature')

print(f"Input shape: {x_tabular.shape}")
print(f"Output shape: {output.shape}")
print(f"Feature weights shape: {feature_weights.shape}")
print(f"Feature importance (first batch): {np.array(feature_weights[0])}")
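# Hedged illustration: if feature_weights holds one importance score per input
# column (a (batch_size, num_features) layout is assumed, not guaranteed by this API),
# averaging over the batch gives a rough global feature ranking.
fw = np.asarray(feature_weights)
if fw.ndim == 2 and fw.shape == (batch_size, num_features):
    mean_importance = fw.mean(axis=0)
    print(f"Features ranked by mean attention weight: {np.argsort(mean_importance)[::-1]}")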

# Example 3: Multi-head attention on sequences
print("\n" + "=" * 50)
print("Example 3: Multi-Head Attention")
print("=" * 50)
batch_size, seq_len, input_dim = 16, 8, 16
x_seq = jnp.array(np.random.randn(batch_size, seq_len, input_dim))

attention_mha = AttentionMechanism(input_dim=input_dim, hidden_dim=64, num_heads=8)
output_mha, weights_mha = attention_mha(x_seq, attention_type='multi_head')

print(f"Input shape: {x_seq.shape}")
print(f"Output shape: {output_mha.shape}")
print(f"Attention weights shape (with heads): {weights_mha.shape}")
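# Hedged illustration: multi-head attention weights are often laid out as
# (batch, num_heads, seq_len, seq_len); that layout is an assumption, so it is
# checked before averaging over the head axis to get one attention map per batch.
w_mha = np.asarray(weights_mha)
if w_mha.ndim == 4 and w_mha.shape[1] == 8:
    print(f"Head-averaged attention map shape: {w_mha.mean(axis=1).shape}")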

# Example 4: Cross-attention
print("\n" + "=" * 50)
print("Example 4: Cross-Attention")
print("=" * 50)
batch_size = 16
query_seq = jnp.array(np.random.randn(batch_size, 5, input_dim))
kv_seq = jnp.array(np.random.randn(batch_size, 10, input_dim))

cross_output, cross_weights = attention_mha(
    None,
    attention_type='cross',
    query=query_seq,
    key_value=kv_seq
)

print(f"Query shape: {query_seq.shape}")
print(f"Key-Value shape: {kv_seq.shape}")
print(f"Cross-attention output shape: {cross_output.shape}")
print(f"Cross-attention weights shape: {cross_weights.shape}")
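# Hedged check: in standard cross-attention the output is aligned with the query
# sequence, so its length should match query_seq.shape[1]. This is a general
# property of the mechanism, assumed (not documented) for this API.
if cross_output.ndim >= 2 and cross_output.shape[1] == query_seq.shape[1]:
    print("Cross-attention output length matches the query length, as expected.")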

# Example 5: Context Vector Attention
print("\n" + "=" * 50)
print("Example 5: Context Vector Attention")
print("=" * 50)
batch_size, seq_len, input_dim = 32, 15, 8
x_context = jnp.array(np.random.randn(batch_size, seq_len, input_dim))

attention_ctx = AttentionMechanism(input_dim=input_dim, hidden_dim=64)
context_output, context_weights = attention_ctx(x_context, attention_type='context_vector')

print(f"Input shape: {x_context.shape}")
print(f"Context output shape: {context_output.shape}")
print(f"Context attention weights shape: {context_weights.shape}")
print(f"Sample context weights (first batch): {np.array(context_weights[0])}")
print("\nNote: Context vector attention produces a fixed-size global representation")
print("regardless of input sequence length, making it ideal for classification tasks.")
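# Hedged demonstration of that note: re-running the same attention object on a
# longer sequence (length 25 instead of 15) should leave the context output shape
# unchanged if the representation really is length-independent.
x_longer = jnp.array(np.random.randn(batch_size, 25, input_dim))
longer_output, _ = attention_ctx(x_longer, attention_type='context_vector')
print(f"Context output shape for seq_len=25: {longer_output.shape}")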

# Demonstrate JAX's JIT compilation benefit
print("\n" + "=" * 50)
print("JAX Performance Benefits")
print("=" * 50)
print("All methods are JIT-compiled for fast execution!")
print("JAX provides automatic differentiation and GPU/TPU acceleration.")
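# Hedged timing sketch: JIT-compiled code typically pays its compilation cost on
# the first call (already made above) and runs fast on repeats. np.asarray() is
# used here to force any asynchronous JAX computation to finish before timing stops.
start = time()
for _ in range(10):
    np.asarray(attention_ctx(x_context, attention_type='context_vector')[0])
elapsed = time() - start
print(f"10 repeat calls: {elapsed:.4f}s total ({elapsed / 10:.4f}s per call).")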