
Commit d541a69

author: dmoi (committed)
fix geometry functions
1 parent 45bc8de commit d541a69

File tree

1 file changed (+41, -0 lines)


foldtree2/src/layers.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn


class SIRENLayer(nn.Module):
    """Sinusoidal representation (SIREN) layer: sin(omega_0 * (Wx + b)).

    Note: the SIREN paper also prescribes a uniform weight initialization
    scaled by omega_0; this layer relies on nn.Linear's default init.
    """

    def __init__(self, in_features, out_features, omega_0=30):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.omega_0 = omega_0  # frequency scale of the sine activation

    def forward(self, x):
        return torch.sin(self.omega_0 * self.linear(x))


class Learnable_Positional_Encoding(nn.Module):
    """Adds a learned per-position embedding to a (batch, seq_len, dim) input."""

    def __init__(self, dim, max_len=1024):
        super(Learnable_Positional_Encoding, self).__init__()
        self.pos_embedding = nn.Embedding(max_len, dim)

    def forward(self, x):
        # x shape: (batch_size, seq_len, dim); seq_len must not exceed max_len
        seq_len = x.size(1)
        position_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
        position_ids = position_ids.unsqueeze(0).expand_as(x[:, :, 0])  # (batch_size, seq_len)
        pos_embeddings = self.pos_embedding(position_ids)
        return x + pos_embeddings


class Position_MLP(torch.nn.Module):
    """MLP head: Linear -> GELU -> Dropout blocks, then a final linear projection."""

    def __init__(self, in_channels=256, hidden_channels=(128, 64), out_channels=32, dropout=0.1):
        super(Position_MLP, self).__init__()
        layers = [
            nn.Linear(in_channels, hidden_channels[0]),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        for i in range(1, len(hidden_channels)):
            layers.append(nn.Linear(hidden_channels[i - 1], hidden_channels[i]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_channels[-1], out_channels))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)
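For context, a minimal usage sketch chaining the three added modules. The import path, shapes, and hyperparameters here are illustrative assumptions, not taken from the repository:

import torch
# assumes foldtree2/src is on the path; class names match the diff above
from layers import SIRENLayer, Learnable_Positional_Encoding, Position_MLP

x = torch.randn(8, 100, 256)                     # made-up (batch=8, seq_len=100, dim=256)
x = Learnable_Positional_Encoding(dim=256)(x)    # add learned positions -> (8, 100, 256)
x = SIRENLayer(256, 256)(x)                      # sinusoidal activation -> (8, 100, 256)
out = Position_MLP(in_channels=256)(x)           # project to 32 channels -> (8, 100, 32)
print(out.shape)  # torch.Size([8, 100, 32])

All three modules act on the last dimension, so they compose directly on (batch, seq_len, dim) tensors; only Learnable_Positional_Encoding constrains seq_len (to max_len).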
