forked from ruvnet/RuVector
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgnn_example.rs
More file actions
60 lines (46 loc) · 2.13 KB
/
gnn_example.rs
File metadata and controls
60 lines (46 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
//! Example demonstrating the Ruvector GNN layer usage
use ruvector_gnn::{RuvectorLayer, Linear, MultiHeadAttention, GRUCell, LayerNorm};
/// Runs the Ruvector GNN demo: one `RuvectorLayer` forward pass over a node
/// and its neighbours, followed by a forward pass through each individual
/// component (`Linear`, `LayerNorm`, `MultiHeadAttention`, `GRUCell`) using
/// the same node embedding.
///
/// All inputs are constant-filled vectors, so this demonstrates construction
/// and wiring only; it prints the resulting lengths rather than checking any
/// values.
fn main() {
    // Dimensions shared by the layers and the synthetic inputs below. Keeping
    // them in one place guarantees the embeddings always match the widths the
    // layers were constructed with.
    const INPUT_DIM: usize = 128;
    const HIDDEN_DIM: usize = 256;
    const NUM_HEADS: usize = 4;
    const LINEAR_OUT_DIM: usize = 64;

    println!("=== Ruvector GNN Layer Example ===\n");

    // Create a GNN layer.
    // Arguments: input_dim, hidden_dim, heads, dropout (0.1).
    let gnn_layer = RuvectorLayer::new(INPUT_DIM, HIDDEN_DIM, NUM_HEADS, 0.1);

    // Simulated embedding for the target node.
    let node_embedding = vec![0.5; INPUT_DIM];

    // Three simulated neighbour embeddings, same width as the node embedding.
    let neighbor_embeddings = vec![
        vec![0.3; INPUT_DIM],
        vec![0.7; INPUT_DIM],
        vec![0.5; INPUT_DIM],
    ];

    // One weight per neighbour (e.g. inverse distances), in the same order as
    // `neighbor_embeddings`.
    let edge_weights = vec![0.8, 0.6, 0.4];

    // Forward pass through the full GNN layer.
    let updated_embedding =
        gnn_layer.forward(&node_embedding, &neighbor_embeddings, &edge_weights);

    println!("Input dimension: {}", node_embedding.len());
    println!("Output dimension: {}", updated_embedding.len());
    println!("Number of neighbors: {}", neighbor_embeddings.len());
    println!("\n✓ GNN layer forward pass successful!");

    // Demonstrate individual components.
    println!("\n=== Individual Components ===\n");

    // 1. Linear layer projecting INPUT_DIM -> LINEAR_OUT_DIM.
    let linear = Linear::new(INPUT_DIM, LINEAR_OUT_DIM);
    let linear_output = linear.forward(&node_embedding);
    // Print the real input width rather than a hard-coded number so the label
    // cannot drift from the data actually fed in.
    println!("Linear layer: {} -> {}", node_embedding.len(), linear_output.len());

    // 2. Layer normalization over INPUT_DIM features. The second argument
    //    (1e-5) is presumably the numerical-stability epsilon — confirm
    //    against the LayerNorm docs.
    let layer_norm = LayerNorm::new(INPUT_DIM, 1e-5);
    let normalized = layer_norm.forward(&node_embedding);
    println!("LayerNorm output dimension: {}", normalized.len());

    // 3. Multi-head attention: the node embedding is the query and the
    //    neighbour embeddings serve as both keys and values. `forward` only
    //    borrows its inputs, so the same vector is passed twice by reference
    //    instead of being copied.
    let attention = MultiHeadAttention::new(INPUT_DIM, NUM_HEADS);
    let attention_output =
        attention.forward(&node_embedding, &neighbor_embeddings, &neighbor_embeddings);
    println!("Multi-head attention output: {}", attention_output.len());

    // 4. GRU cell, starting from an all-zero hidden state of width HIDDEN_DIM.
    let gru = GRUCell::new(INPUT_DIM, HIDDEN_DIM);
    let hidden_state = vec![0.0; HIDDEN_DIM];
    let new_hidden = gru.forward(&node_embedding, &hidden_state);
    println!("GRU cell output dimension: {}", new_hidden.len());

    println!("\n✓ All components working correctly!");
}