
About CrossAttnDownBlock2D in diffusers.models.unets.unet_2d_blocks #10580

Description

@Charging-up

The CrossAttnDownBlock2D class is part of the downsampling path. In the forward method, the input sample is processed first by the self.resnets modules and then by the self.attentions modules (each resnet runs before its paired attention block). However, in __init__, self.attentions and self.resnets are defined and registered in the opposite order:

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
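
For reference, here is a simplified sketch of the forward loop (paraphrased from diffusers; gradient checkpointing, attention masks, and cross-attention kwargs are omitted). It shows that each resnet is paired with an attention block and runs first, the reverse of the registration order above:

    # Simplified sketch of CrossAttnDownBlock2D.forward (paraphrased).
    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()
        # Resnet runs before its paired attention block in every iteration.
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )[0]
            output_states = output_states + (hidden_states,)
        # Downsamplers, if present, run last.
        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states = output_states + (hidden_states,)
        return hidden_states, output_states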

As a result, when I inspect the network structure, the attentions modules are printed before the resnets modules, which is not very friendly for anyone trying to understand the architecture.
Here is one of the CrossAttnDownBlock2D modules:

CrossAttnDownBlock2D(
  (attentions): ModuleList(
    (0-1): 2 x Transformer2DModel(
      (norm): GroupNorm(32, 640, eps=1e-06, affine=True)
      (proj_in): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
      (transformer_blocks): ModuleList(
        (0): BasicTransformerBlock(
          (norm1): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (attn1): Attention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=640, out_features=640, bias=False)
            (to_v): Linear(in_features=640, out_features=640, bias=False)
            (to_out): ModuleList(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm2): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (attn2): Attention(
            (to_q): Linear(in_features=640, out_features=640, bias=False)
            (to_k): Linear(in_features=768, out_features=640, bias=False)
            (to_v): Linear(in_features=768, out_features=640, bias=False)
            (to_out): ModuleList(
              (0): Linear(in_features=640, out_features=640, bias=True)
              (1): Dropout(p=0.0, inplace=False)
            )
          )
          (norm3): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
          (ff): FeedForward(
            (net): ModuleList(
              (0): GEGLU(
                (proj): Linear(in_features=640, out_features=5120, bias=True)
              )
              (1): Dropout(p=0.0, inplace=False)
              (2): Linear(in_features=2560, out_features=640, bias=True)
            )
          )
        )
      )
      (proj_out): Conv2d(640, 640, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (resnets): ModuleList(
    (0): ResnetBlock2D(
      (norm1): GroupNorm(32, 320, eps=1e-05, affine=True)
      (conv1): Conv2d(320, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
      (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (nonlinearity): SiLU()
      (conv_shortcut): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): ResnetBlock2D(
      (norm1): GroupNorm(32, 640, eps=1e-05, affine=True)
      (conv1): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_emb_proj): Linear(in_features=1280, out_features=640, bias=True)
      (norm2): GroupNorm(32, 640, eps=1e-05, affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (nonlinearity): SiLU()
    )
  )
  (downsamplers): ModuleList(
    (0): Downsample2D(
      (conv): Conv2d(640, 640, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
)

As you can see, the order in which these two modules appear in the printed architecture is the reverse of the order in which they are called.
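
A minimal fix sketch, assuming nothing else depends on registration order: simply swap the two assignment lines in __init__ so the printed structure follows the execution order. Since state_dict keys are derived from attribute names rather than registration order, existing checkpoints would still load unchanged:

    # Hypothetical reordering in CrossAttnDownBlock2D.__init__:
    # registering resnets first makes print(module) list them before
    # attentions, matching the forward pass. state_dict keys are
    # name-based, so this does not affect checkpoint loading.
    self.resnets = nn.ModuleList(resnets)
    self.attentions = nn.ModuleList(attentions)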
