import torch
import torch.nn as nn

class SIRENLayer(nn.Module):
    """Linear layer followed by a sine activation scaled by omega_0 (SIREN-style).

    Note: the SIREN paper also prescribes a specific uniform weight initialization,
    which is not applied here.
    """

    def __init__(self, in_features, out_features, omega_0=30):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.omega_0 = omega_0

    def forward(self, x):
        # Sine activation; larger omega_0 lets the layer represent higher frequencies.
        return torch.sin(self.omega_0 * self.linear(x))

class Learnable_Positional_Encoding(nn.Module):
    """Adds a learned per-position embedding to the input sequence."""

    def __init__(self, dim, max_len=1024):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, dim)

    def forward(self, x):
        # x shape: (batch_size, seq_len, dim)
        seq_len = x.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
        position_ids = position_ids.unsqueeze(0).expand(x.size(0), seq_len)  # (batch_size, seq_len)
        pos_embeddings = self.pos_embedding(position_ids)  # (batch_size, seq_len, dim)
        return x + pos_embeddings

class Position_MLP(nn.Module):
    """MLP head: Linear -> GELU -> Dropout blocks over the hidden sizes, then a final projection."""

    def __init__(self, in_channels=256, hidden_channels=[128, 64], out_channels=32, dropout=0.1):
        super().__init__()
        layers = [
            nn.Linear(in_channels, hidden_channels[0]),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        for i in range(1, len(hidden_channels)):
            layers.append(nn.Linear(hidden_channels[i - 1], hidden_channels[i]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_channels[-1], out_channels))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
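

# Illustrative smoke test (not part of the original module): the batch size, sequence
# length, and feature dimensions below are arbitrary assumptions chosen only to show
# how the three modules might be wired together and to sanity-check output shapes.
if __name__ == "__main__":
    batch_size, seq_len, dim = 4, 16, 256

    tokens = torch.randn(batch_size, seq_len, dim)

    # Add learnable positional information, then project each position with the MLP head.
    pos_enc = Learnable_Positional_Encoding(dim=dim, max_len=1024)
    mlp = Position_MLP(in_channels=dim, hidden_channels=[128, 64], out_channels=32)
    out = mlp(pos_enc(tokens))  # (batch_size, seq_len, 32)

    # SIREN layer applied to 2D coordinates, as in typical implicit-representation setups.
    coords = torch.rand(batch_size, 2)
    siren = SIRENLayer(in_features=2, out_features=dim)
    feats = siren(coords)  # (batch_size, dim)

    print(out.shape, feats.shape)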