-
Notifications
You must be signed in to change notification settings - Fork 60
Expand file tree
/
Copy pathembeddings.py
More file actions
221 lines (182 loc) · 7.82 KB
/
embeddings.py
File metadata and controls
221 lines (182 loc) · 7.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from functools import partial
import numpy as np
import torch
from weathergen.model.attention import MultiSelfAttentionHead
from weathergen.model.layers import MLP
# from weathergen.model.mlp import MLP
from weathergen.model.norms import RMSNorm
from weathergen.model.positional_encoding import positional_encoding_harmonic
from weathergen.model.utils import cond_checkpoint
class StreamEmbedTransformer(torch.nn.Module):
    """Transformer-based embedding network for a single input stream.

    The input is linearly embedded, passed through ``num_blocks`` pairs of
    (``MultiSelfAttentionHead``, ``MLP``) layers and then projected back out
    ("unembedded"). Two input layouts are supported via ``mode``:

    * ``"channels"``: the last two input axes are swapped before the embedding
      (``Linear(token_size, dim_embed)``), so attention runs over the axis of
      length ``num_channels``; the result is reshaped to
      ``(-1, num_tokens, dim_out)``.
    * ``"columns"``: the embedding (``Linear(num_channels, dim_embed)``) is applied
      to the last input axis; the result is a single token of width ``dim_out``
      cast to ``float16``.
    """

    def __init__(
        self,
        cf,
        mode,
        num_tokens,
        token_size,
        num_channels,
        dim_embed,
        dim_out,
        num_blocks,
        num_heads,
        dropout_rate=0.0,
        norm_type="LayerNorm",
        unembed_mode="full",
        stream_name="stream_embed",
    ):
        """Constructor.

        Args:
            cf: configuration object; only
                ``cf.get("stream_embed_gradient_checkpoint_enabled", True)`` is read.
            mode: ``"channels"`` or ``"columns"`` (see the class docstring).
            num_tokens: number of output tokens in ``"channels"`` mode.
            token_size: width of one input token (input size of ``embed`` in
                ``"channels"`` mode, of ``unembed2`` in ``"columns"`` mode).
            num_channels: number of input channels.
            dim_embed: width of the internal transformer representation.
            dim_out: width of each output token.
            num_blocks: number of (attention, MLP) layer pairs.
            num_heads: number of attention heads.
            dropout_rate: dropout used inside the attention and MLP layers.
            norm_type: ``"LayerNorm"`` selects ``torch.nn.LayerNorm``; any other
                value selects ``RMSNorm``.
            unembed_mode: { 'full' , 'block'}
                full : monolithic (and correspondingly large) unembedding network
                    that maps (num_channels x dim_embed) to (num_tokens x dim_out),
                    allowing for mixing between channels/columns.
                block : per-channel/column unembedding network (which is hence a
                    block-sparse form of full). This is the only mode supported
                    with ``mode == "columns"``.
            stream_name: suffix used to build ``self.name``.

        Raises:
            ValueError: for an unknown ``mode`` or ``unembed_mode``, for
                ``"columns"`` mode with an ``unembed_mode`` other than ``"block"``,
                or for ``"columns"`` mode when ``dim_out`` is not a perfect square.
        """
        super().__init__()

        self.name = f"StreamEmbedder_{stream_name}"
        self.mode = mode
        self.num_tokens = num_tokens
        self.token_size = token_size
        self.num_channels = num_channels
        self.dim_in = token_size if mode == "channels" else num_channels
        self.dim_embed = dim_embed
        self.dim_out = dim_out
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.unembed_mode = unembed_mode
        self.cf = cf

        norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

        # NOTE: module creation order below is kept as-is; it determines the order
        # in which random weight initialisation consumes the RNG.
        self.layers = torch.nn.ModuleList()
        for _ in range(self.num_blocks):
            self.layers.append(
                MultiSelfAttentionHead(
                    self.dim_embed,
                    self.num_heads,
                    dropout_rate=dropout_rate,
                    with_qk_lnorm=True,
                    with_flash=True,
                )
            )
            self.layers.append(
                MLP(
                    self.dim_embed,
                    self.dim_embed,
                    hidden_factor=2,
                    dropout_rate=dropout_rate,
                    with_residual=True,
                )
            )

        if mode == "channels":
            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)

            if self.unembed_mode == "full":
                self.ln_final = norm(num_channels * self.dim_embed, eps=1e-03)
                self.unembed = torch.nn.Linear(
                    num_channels * self.dim_embed,
                    self.num_tokens * self.dim_out,
                )

            elif self.unembed_mode == "block":
                # Each channel gets an equal share of the num_tokens * dim_out output
                # entries; a remainder from the integer division is zero-padded in
                # forward_channels.
                dim_out_block = (self.num_tokens * self.dim_out) // num_channels
                self.unembed = torch.nn.ModuleList(
                    [torch.nn.Linear(dim_embed, dim_out_block) for _ in range(num_channels)]
                )
                self.ln_final = torch.nn.ModuleList(
                    [norm(dim_embed, eps=1e-6) for _ in range(num_channels)]
                )

            else:
                raise ValueError(f"Unknown unembed mode: {unembed_mode}")

        elif mode == "columns":
            self.embed = torch.nn.Linear(self.dim_in, self.dim_embed)

            # Only "block" is supported here at the moment. Raise instead of assert
            # so the check is not stripped under `python -O`.
            if self.unembed_mode != "block":
                raise ValueError(
                    f"Unknown unembed mode for 'columns' mode: {unembed_mode} "
                    "(only 'block' is supported)"
                )

            # padding needed if the unembedded columns cannot be concatenated to
            # dim_out (e.g. GPSRO)
            # NOTE(review): pad / out_pad / unembed are not read by forward_columns,
            # which uses unembed1/unembed2 instead. Presumably they are kept so that
            # existing checkpoints still load with strict state-dict matching;
            # confirm before removing.
            self.pad = self.dim_out % token_size
            self.out_pad = torch.nn.Parameter(torch.zeros(self.pad), requires_grad=False)
            self.unembed = torch.nn.Linear(
                self.dim_embed,
                self.num_tokens * (self.dim_out // token_size),
            )
            self.ln_final = norm(dim_out, eps=1e-6)

            # Factorised read-out: unembed1 maps dim_embed -> dim1 and unembed2 maps
            # token_size -> dim1, so the flattened result has dim1 * dim1 entries,
            # which must equal dim_out.
            # TODO: factorization when sqrt is not int
            dim1 = int(np.sqrt(dim_out))
            if dim1 * dim1 != dim_out:
                raise ValueError(
                    f"'columns' mode requires dim_out to be a perfect square, got {dim_out}"
                )
            self.unembed1 = torch.nn.Linear(self.dim_embed, dim1)
            self.unembed_nonlin = torch.nn.GELU()
            self.unembed2 = torch.nn.Linear(self.token_size, dim1)

        else:
            raise ValueError(f"Unknown mode: {mode}")

        self.dropout_final = torch.nn.Dropout(0.1)

        # cond_checkpoint receives the config flag as its first argument
        # (default True); the remaining arguments are supplied per call.
        self.checkpoint_stream_embed = partial(
            cond_checkpoint, self.cf.get("stream_embed_gradient_checkpoint_enabled", True)
        )

    def forward_channels(self, x_in):
        """Forward pass for ``mode == "channels"``.

        The last two axes of ``x_in`` are swapped before the embedding, so the last
        axis of ``x_in`` must have length ``num_channels`` and the second-to-last
        ``token_size`` (assumed layout ``(batch, token_size, num_channels)`` --
        confirm against the caller).

        Returns:
            Tensor reshaped to ``(-1, num_tokens, dim_out)``.
        """
        # embed provided input data, then apply positional_encoding_harmonic
        x = positional_encoding_harmonic(
            self.checkpoint_stream_embed(self.embed, x_in.transpose(-2, -1), use_reentrant=False)
        )

        for layer in self.layers:
            x = self.checkpoint_stream_embed(layer, x, use_reentrant=False)

        # read out
        if self.unembed_mode == "full":
            out = self.checkpoint_stream_embed(
                self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False
            )
        elif self.unembed_mode == "block":
            # separate norm + linear read-out for each channel slice x[:, i]
            out = [
                self.checkpoint_stream_embed(ue, ln(x[:, i]), use_reentrant=False)
                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
            ]
            out = torch.stack(out, dim=1).flatten(-2, -1)
        else:
            raise ValueError(f"Unknown unembed mode: {self.unembed_mode}")

        # In "block" mode each channel produces floor(num_tokens * dim_out /
        # num_channels) entries, so the concatenation can fall short of the
        # num_tokens * dim_out entries the reshape below needs. Zero-pad up to that
        # full size. Padding only to dim_out is sufficient just for num_tokens == 1;
        # otherwise the reshape fails or silently mixes samples across the batch.
        dim_total = self.num_tokens * self.dim_out
        if out.shape[-1] < dim_total:
            out = torch.nn.functional.pad(out, [0, dim_total - out.shape[-1]], value=0.0)

        # final reshape
        out = self.dropout_final(out.reshape(-1, self.num_tokens, self.dim_out))
        return out

    def forward_columns(self, x_in):
        """Forward pass for ``mode == "columns"``.

        ``embed`` is applied to the last axis of ``x_in`` (length ``num_channels``),
        and ``unembed2`` later consumes the second-to-last axis (length
        ``token_size``).

        Returns:
            Tensor with a singleton axis inserted at position 1 and a last axis of
            size ``dim_out`` (``(batch, 1, dim_out)`` for a 3-D input -- assumed,
            confirm against the caller), cast to ``torch.float16``.
        """
        # embed provided input data
        x = positional_encoding_harmonic(
            self.checkpoint_stream_embed(self.embed, x_in, use_reentrant=False)
        )

        for layer in self.layers:
            x = self.checkpoint_stream_embed(layer, x, use_reentrant=False)

        # factorised read-out: last axis dim_embed -> dim1 (unembed1); then, after
        # swapping the last two axes, token_size -> dim1 (unembed2)
        out = self.checkpoint_stream_embed(self.unembed1, x, use_reentrant=False)
        out = self.unembed_nonlin(out)
        out = self.checkpoint_stream_embed(
            self.unembed2, out.transpose(-2, -1), use_reentrant=False
        )
        # (..., dim1, dim1) -> (..., dim_out), then insert a singleton axis at dim 1
        out = out.flatten(-2, -1).unsqueeze(1)

        # final normalize and dropout
        out = self.dropout_final(self.ln_final(out))
        # NOTE(review): hard-coded float16 cast; forward_channels keeps the
        # computation dtype. Confirm this is what downstream code expects
        # (e.g. when running under bfloat16 autocast).
        return out.to(torch.float16)

    def forward(self, x_in):
        """Dispatch to ``forward_channels`` or ``forward_columns`` based on ``self.mode``."""
        if self.mode == "channels":
            return self.forward_channels(x_in)
        elif self.mode == "columns":
            return self.forward_columns(x_in)
        else:
            raise ValueError(f"Unknown mode {self.mode}")
class StreamEmbedLinear(torch.nn.Module):
    """Linear stream embedding: merge the two trailing input axes, then project.

    Args:
        dim_in: size of the merged trailing axes (product of the last two input
            dimensions).
        dim_out: size of the produced embedding.
        stream_name: suffix used to build ``self.name``.
    """

    def __init__(self, dim_in, dim_out, stream_name="stream_embed"):
        """Constructor"""
        super().__init__()
        self.name = f"StreamEmbedder_{stream_name}"
        # single affine projection dim_in -> dim_out
        self.layer = torch.nn.Linear(dim_in, dim_out)

    def forward(self, x):
        """Collapse the last two axes of ``x`` into one and apply ``self.layer``."""
        merged = torch.flatten(x, start_dim=-2, end_dim=-1)
        return self.layer(merged)