The CrossAttnDownBlock2D class is part of the down-sampling path. In its forward method, sample is processed first by the self.resnets modules and then by the self.attentions modules; however, in the __init__ method, the two module lists are defined in the opposite order from how they are called:
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
As a result, when I print the network structure, the attentions modules appear before the resnets modules, which is confusing for anyone trying to understand the architecture.
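To make the mismatch concrete, here is a minimal runnable mock of the pattern (an illustrative sketch, not the actual diffusers implementation; nn.Identity stands in for the real Transformer2DModel and ResnetBlock2D):

import torch.nn as nn

class MockCrossAttnDownBlock2D(nn.Module):
    def __init__(self):
        super().__init__()
        # __init__ registers attentions before resnets ...
        self.attentions = nn.ModuleList([nn.Identity(), nn.Identity()])
        self.resnets = nn.ModuleList([nn.Identity(), nn.Identity()])

    def forward(self, sample):
        # ... but forward runs each resnet before its paired attention block
        for resnet, attn in zip(self.resnets, self.attentions):
            sample = resnet(sample)
            sample = attn(sample)
        return sample

print(MockCrossAttnDownBlock2D())  # lists "attentions" before "resnets"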
Here is one of the CrossAttnDownBlock2D modules:
CrossAttnDownBlock2D(
(attentions): ModuleList(
(0-1): 2 x Transformer2DModel(
(norm): GroupNorm(32, 640, eps=1e-06, affine=True)
(proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
(transformer_blocks): ModuleList(
(0): BasicTransformerBlock(
(norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(attn1): Attention(
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=640, out_features=640, bias=False)
(to_v): Linear(in_features=640, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(attn2): Attention(
(to_q): Linear(in_features=640, out_features=640, bias=False)
(to_k): Linear(in_features=768, out_features=640, bias=False)
(to_v): Linear(in_features=768, out_features=640, bias=False)
(to_out): ModuleList(
(0): Linear(in_features=640, out_features=640, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
(norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
(ff): FeedForward(
(net): ModuleList(
(0): GEGLU(
(proj): Linear(in_features=640, out_features=5120, bias=True)
)
(1): Dropout(p=0.0, inplace=False)
(2): Linear(in_features=2560, out_features=640, bias=True)
)
)
)
)
(proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
(conv1): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
(conv1): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
(norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
As shown above, the order in which the two modules appear in the printed architecture is the reverse of the order in which they are called.
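A possible remedy (my suggestion, not a confirmed patch) would be to swap the two assignments in __init__ so that the printed order matches the call order. Since submodule names, and therefore state_dict keys such as resnets.0.conv1.weight, come from the attribute names rather than the registration order, this should not affect checkpoint loading. A sketch on the mock above:

import torch.nn as nn

class ReorderedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical fix: register resnets first, matching forward()'s call order.
        self.resnets = nn.ModuleList([nn.Identity(), nn.Identity()])
        self.attentions = nn.ModuleList([nn.Identity(), nn.Identity()])

print(ReorderedBlock())  # now "resnets" is listed first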