|
15 | 15 | ) |
16 | 16 |
|
class ControlNetConditioningLayer(nn.Module):
    """Encode a conditioning image into UNet-resolution feature maps.

    A small convolutional stem: one channel-lifting conv, then for each
    intermediate channel pair a same-channel conv followed by a stride-2
    downsampling conv (each with SiLU), and a final channel-lifting conv.
    With the default ``channels`` this maps 3 input channels to 320 feature
    channels at 1/8 spatial resolution.

    NOTE(review): the defaults ``device="cuda:0"``/``dtype=torch.float16``
    assume a CUDA machine — pass explicit values on CPU-only systems.
    """

    def __init__(self, channels=(3, 16, 32, 96, 256, 320), device="cuda:0", dtype=torch.float16):
        super().__init__()
        # Shared conv settings so device/dtype are applied uniformly.
        conv_kwargs = dict(kernel_size=3, padding=1, device=device, dtype=dtype)

        layers = [
            torch.nn.Conv2d(channels[0], channels[1], **conv_kwargs),
            torch.nn.SiLU(),
        ]
        # Pair each intermediate channel count with its successor:
        # (c1, c2), (c2, c3), ... up to channels[-2].
        for c_in, c_out in zip(channels[1:-2], channels[2:-1]):
            layers += [
                torch.nn.Conv2d(c_in, c_in, **conv_kwargs),
                torch.nn.SiLU(),
                torch.nn.Conv2d(c_in, c_out, stride=2, **conv_kwargs),
                torch.nn.SiLU(),
            ]
        layers.append(torch.nn.Conv2d(channels[-2], channels[-1], **conv_kwargs))

        self.blocks = torch.nn.ModuleList(layers)
29 | 29 |
|
30 | 30 | def forward(self, conditioning): |
31 | 31 | for block in self.blocks: |
@@ -496,64 +496,64 @@ def __init__( |
496 | 496 | ): |
497 | 497 | super().__init__() |
498 | 498 | self.time_embedding = TimestepEmbeddings(dim_in=320, dim_out=1280, device=device, dtype=dtype) |
499 | | - self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) |
| 499 | + self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1, device=device, dtype=dtype) |
500 | 500 |
|
501 | | - self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320)) |
| 501 | + self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320), device=device, dtype=dtype) |
502 | 502 |
|
503 | 503 | self.blocks = torch.nn.ModuleList([ |
504 | 504 | # CrossAttnDownBlock2D |
505 | | - ResnetBlock(320, 320, 1280), |
506 | | - AttentionBlock(8, 40, 320, 1, 768), |
| 505 | + ResnetBlock(320, 320, 1280, device=device, dtype=dtype), |
| 506 | + AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype), |
507 | 507 | PushBlock(), |
508 | | - ResnetBlock(320, 320, 1280), |
509 | | - AttentionBlock(8, 40, 320, 1, 768), |
| 508 | + ResnetBlock(320, 320, 1280, device=device, dtype=dtype), |
| 509 | + AttentionBlock(8, 40, 320, 1, 768, device=device, dtype=dtype), |
510 | 510 | PushBlock(), |
511 | | - DownSampler(320), |
| 511 | + DownSampler(320, device=device, dtype=dtype), |
512 | 512 | PushBlock(), |
513 | 513 | # CrossAttnDownBlock2D |
514 | | - ResnetBlock(320, 640, 1280), |
515 | | - AttentionBlock(8, 80, 640, 1, 768), |
| 514 | + ResnetBlock(320, 640, 1280, device=device, dtype=dtype), |
| 515 | + AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype), |
516 | 516 | PushBlock(), |
517 | | - ResnetBlock(640, 640, 1280), |
518 | | - AttentionBlock(8, 80, 640, 1, 768), |
| 517 | + ResnetBlock(640, 640, 1280, device=device, dtype=dtype), |
| 518 | + AttentionBlock(8, 80, 640, 1, 768, device=device, dtype=dtype), |
519 | 519 | PushBlock(), |
520 | | - DownSampler(640), |
| 520 | + DownSampler(640, device=device, dtype=dtype), |
521 | 521 | PushBlock(), |
522 | 522 | # CrossAttnDownBlock2D |
523 | | - ResnetBlock(640, 1280, 1280), |
524 | | - AttentionBlock(8, 160, 1280, 1, 768), |
| 523 | + ResnetBlock(640, 1280, 1280, device=device, dtype=dtype), |
| 524 | + AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), |
525 | 525 | PushBlock(), |
526 | | - ResnetBlock(1280, 1280, 1280), |
527 | | - AttentionBlock(8, 160, 1280, 1, 768), |
| 526 | + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), |
| 527 | + AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), |
528 | 528 | PushBlock(), |
529 | | - DownSampler(1280), |
| 529 | + DownSampler(1280, device=device, dtype=dtype), |
530 | 530 | PushBlock(), |
531 | 531 | # DownBlock2D |
532 | | - ResnetBlock(1280, 1280, 1280), |
| 532 | + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), |
533 | 533 | PushBlock(), |
534 | | - ResnetBlock(1280, 1280, 1280), |
| 534 | + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), |
535 | 535 | PushBlock(), |
536 | 536 | # UNetMidBlock2DCrossAttn |
537 | | - ResnetBlock(1280, 1280, 1280), |
538 | | - AttentionBlock(8, 160, 1280, 1, 768), |
539 | | - ResnetBlock(1280, 1280, 1280), |
| 537 | + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), |
| 538 | + AttentionBlock(8, 160, 1280, 1, 768, device=device, dtype=dtype), |
| 539 | + ResnetBlock(1280, 1280, 1280, device=device, dtype=dtype), |
540 | 540 | PushBlock() |
541 | 541 | ]) |
542 | 542 |
|
543 | 543 | self.controlnet_blocks = torch.nn.ModuleList([ |
544 | | - torch.nn.Conv2d(320, 320, kernel_size=(1, 1)), |
545 | | - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), |
546 | | - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), |
547 | | - torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False), |
548 | | - torch.nn.Conv2d(640, 640, kernel_size=(1, 1)), |
549 | | - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False), |
550 | | - torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False), |
551 | | - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)), |
552 | | - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), |
553 | | - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), |
554 | | - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), |
555 | | - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), |
556 | | - torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False), |
| 544 | + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), device=device, dtype=dtype), |
| 545 | + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 546 | + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 547 | + torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 548 | + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), device=device, dtype=dtype), |
| 549 | + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 550 | + torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 551 | + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), device=device, dtype=dtype), |
| 552 | + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 553 | + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 554 | + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 555 | + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
| 556 | + torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False, device=device, dtype=dtype), |
557 | 557 | ]) |
558 | 558 |
|
559 | 559 | def forward( |
|
0 commit comments