Skip to content

Commit c8113b8

Browse files
committed
update
1 parent 0962c2b commit c8113b8

File tree

2 files changed

+41
-26
lines changed

2 files changed

+41
-26
lines changed

LICENSE

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ National Laboratory, U.S. Dept. of Energy nor the names of its contributors
1818
may be used to endorse or promote products derived from this software
1919
without specific prior written permission.
2020

21-
2221
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
2322
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
2423
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

StFT_3D.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torch.nn as nn
33
import torch.nn.functional as F
4+
from einops import rearrange
45
from model_utils import TransformerLayer, get_2d_sincos_pos_embed
56

67

@@ -34,19 +35,19 @@ def __init__(
3435
self.pos_embed = nn.Parameter(
3536
torch.randn(1, num_patches, dim), requires_grad=False
3637
)
37-
self.pos_embed_f = nn.Parameter(
38+
self.pos_embed_fno = nn.Parameter(
3839
torch.randn(1, num_patches, dim), requires_grad=False
3940
)
4041
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
41-
pos_embed_f = get_2d_sincos_pos_embed(self.pos_embed_f.shape[-1], grid_size)
42+
pos_embed_fno = get_2d_sincos_pos_embed(self.pos_embed_fno.shape[-1], grid_size)
4243
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
43-
self.pos_embed_f.data.copy_(
44-
torch.from_numpy(pos_embed_f).float().unsqueeze(0)
44+
self.pos_embed_fno.data.copy_(
45+
torch.from_numpy(pos_embed_fno).float().unsqueeze(0)
4546
)
4647
self.encoder_layers = nn.ModuleList(
4748
[TransformerLayer(dim, num_heads, mlp_dim, act) for _ in range(depth)]
4849
)
49-
self.encoder_layers_f = nn.ModuleList(
50+
self.encoder_layers_fno = nn.ModuleList(
5051
[TransformerLayer(dim, num_heads, mlp_dim, act) for _ in range(depth)]
5152
)
5253
self.head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, out_dim))
@@ -63,26 +64,34 @@ def forward(self, x):
6364
n, l, _, ph, pw = x.shape
6465
x_or = x[:, :, : self.cond_time * self.freq_in_channels]
6566
x_added = x[:, :, (self.cond_time * self.freq_in_channels) :]
66-
x_or = x_or.permute(0, 1, 3, 4, 2).view(n, l, ph, pw, self.cond_time, self.freq_in_channels)
67+
x_or = rearrange(
68+
x_or,
69+
"n l (t v) ph pw -> n l ph pw t v",
70+
t=self.cond_time,
71+
v=self.freq_in_channels,
72+
)
6773
grid_dup = x_or[:, :, :, :, :1, -2:].repeat(1, 1, 1, 1, self.layer_indx, 1)
68-
69-
x_added = x_added.permute(0, 1, 3, 4, 2).view(n, l, ph, pw, self.layer_indx, self.freq_in_channels - 2)
74+
x_added = rearrange(
75+
x_added,
76+
"n l (t v) ph pw -> n l ph pw t v",
77+
t=self.layer_indx,
78+
v=self.freq_in_channels - 2,
79+
)
7080
x_added = torch.cat((x_added, grid_dup), axis=-1)
7181
x = torch.cat((x_or, x_added), axis=-2)
7282
x = self.p(x)
73-
v, t = x.shape[-1], x.shape[-2]
74-
x = x.permute(0, 1, 5, 4, 2, 3).view(n * l, v, t, ph, pw)
83+
x = rearrange(x, "n l ph pw t v -> (n l) v t ph pw")
7584
x_ft = torch.fft.rfftn(x, dim=[2, 3, 4])[
7685
:, :, :, : self.modes[0], : self.modes[1]
7786
]
7887
x_ft_real = (x_ft.real).flatten(1)
7988
x_ft_imag = (x_ft.imag).flatten(1)
80-
x_ft_real = x_ft_real.view(n, l, -1)
81-
x_ft_imag = x_ft_imag.view(n, l, -1)
89+
x_ft_real = rearrange(x_ft_real, "(n l) D -> n l D", n=n, l=l)
90+
x_ft_imag = rearrange(x_ft_imag, "(n l) D -> n l D", n=n, l=l)
8291
x_ft_real_imag = torch.cat((x_ft_real, x_ft_imag), axis=-1)
8392
x = self.linear(x_ft_real_imag)
84-
x = x + self.pos_embed_f
85-
for layer in self.encoder_layers_f:
93+
x = x + self.pos_embed_fno
94+
for layer in self.encoder_layers_fno:
8695
x = layer(x)
8796
x_real, x_imag = self.q(x).split(
8897
self.modes[0] * self.modes[1] * self.lift_channel, dim=-1
@@ -101,19 +110,20 @@ def forward(self, x):
101110
)
102111
out_ft[:, :, :, : self.modes[0], : self.modes[1]] = x_complex
103112
x = torch.fft.irfftn(out_ft, s=(1, ph, pw))
104-
x = x.permute(0, 3, 4, 1, 2).view(n * l, ph, pw, -1)
113+
x = rearrange(x, "(n l) v t ph pw -> (n l) ph pw (v t)", n=n, l=l, t=1)
105114
x = self.down(x)
106-
c = x.shape[-1]
107-
x_f = x.permute(0, 3, 1, 2).view(n, l, c, ph, pw)
115+
x_fno = rearrange(x, "(n l) ph pw c -> n l c ph pw", n=n, l=l)
108116
x = x_copy
109117
_, _, _, ph, pw = x.shape
110118
x = x.flatten(2)
111119
x = self.token_embed(x) + self.pos_embed
112120
for layer in self.encoder_layers:
113121
x = layer(x)
114122
x = self.head(x)
115-
x = x.view(n, l, self.out_channel, ph, pw)
116-
x = x + x_f
123+
x = rearrange(
124+
x, "n l (c ph pw) -> n l c ph pw", c=self.out_channel, ph=ph, pw=pw
125+
)
126+
x = x + x_fno
117127
return x
118128

119129

@@ -206,7 +216,7 @@ def __init__(
206216
def forward(self, x, grid):
207217
grid_dup = grid[None, :, :, :].repeat(x.shape[0], x.shape[1], 1, 1, 1)
208218
x = torch.cat((x, grid_dup), axis=2)
209-
x = x.view(x.shape[0], x.shape[1] * x.shape[2], x.shape[3], x.shape[4])
219+
x = rearrange(x, "B L C H W -> B (L C) H W")
210220
layer_outputs = []
211221
patches = x
212222
restore_params = []
@@ -240,20 +250,26 @@ def forward(self, x, grid):
240250
)
241251

242252
patches = patches.unfold(2, p1, step_h).unfold(3, p2, step_w)
243-
n, c, h, w, ph, pw = x.shape
244-
patches = patches.permute(0, 2, 3, 1, 4, 5).view(n, h*w, c, ph, pw)
253+
patches = rearrange(patches, "n c h w ph pw -> n (h w) c ph pw")
254+
245255
processed_patches = self.blocks[depth](patches)
246256

247-
patches = processed_patches.permute(0, 2, 1, 3, 4).view(n, c, h, w, ph, pw)
257+
patches = rearrange(
258+
processed_patches, "n (h w) c ph pw -> n c h w ph pw", h=h, w=w
259+
)
260+
248261
output = F.fold(
249-
torch.reshape(patches.permute(0, 1, 4, 5, 2, 3),(n, c * ph * pw, h * w)),
262+
rearrange(patches, "n c h w ph pw -> n (c ph pw) (h w)"),
250263
output_size=(H_pad, W_pad),
251264
kernel_size=(p1, p2),
252265
stride=(step_h, step_w),
253266
)
254267

255268
overlap_count = F.fold(
256-
torch.reshape(torch.ones_like(patches).permute(0, 1, 4, 5, 2, 3),(n, c * ph * pw, h * w)),
269+
rearrange(
270+
torch.ones_like(patches),
271+
"n c h w ph pw -> n (c ph pw) (h w)",
272+
),
257273
output_size=(H_pad, W_pad),
258274
kernel_size=(p1, p2),
259275
stride=(step_h, step_w),

0 commit comments

Comments
 (0)