Skip to content

Commit 4595d9d

Browse files
committed
cleanup
1 parent 5cd08b2 commit 4595d9d

File tree

1 file changed

+0
-31
lines changed

1 file changed

+0
-31
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,6 @@ def convert_image_to(img_type, image):
5858
return image.convert(img_type)
5959
return image
6060

61-
# small helper modules
62-
63-
class Residual(nn.Module):
64-
def __init__(self, fn):
65-
super().__init__()
66-
self.fn = fn
67-
68-
def forward(self, x, *args, **kwargs):
69-
return self.fn(x, *args, **kwargs) + x
70-
7161
# use layernorm without bias, more stable
7262

7363
class LayerNorm(nn.Module):
@@ -259,27 +249,6 @@ def forward(self, x, time = None):
259249

260250
return self.net(x)
261251

262-
class FiLM(nn.Module):
263-
def __init__(
264-
self,
265-
dim,
266-
hidden_dim
267-
):
268-
super().__init__()
269-
self.net = nn.Sequential(
270-
nn.Linear(dim, hidden_dim * 4),
271-
nn.SiLU(),
272-
nn.Linear(hidden_dim * 4, hidden_dim * 2)
273-
)
274-
275-
nn.init.zeros_(self.net[-1].weight)
276-
nn.init.zeros_(self.net[-1].bias)
277-
278-
def forward(self, conditions, hiddens):
279-
scale, shift = self.net(conditions).chunk(2, dim = -1)
280-
scale, shift = map(lambda t: rearrange(t, 'b d -> b 1 d'), (scale, shift))
281-
return hiddens * (scale + 1) + shift
282-
283252
# model
284253

285254
class RIN(nn.Module):

0 commit comments

Comments
 (0)