
Commit 93c34f3

set module device to skip weight init (#207)

* set module device to skip weight init
* up
* set device default to cuda:0

1 parent f24e582 commit 93c34f3

2 files changed (+23, -19 lines)

diffsynth_engine/models/wan/wan_image_encoder.py

Lines changed: 1 addition & 1 deletion
@@ -439,7 +439,7 @@ class WanImageEncoder(PreTrainedModel):
     def __init__(self, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
         super().__init__()
         # init model
-        self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device="cpu")
+        self.model, self.transforms = clip_xlm_roberta_vit_h_14(dtype=torch.float32, device=device)
 
     def encode_image(self, images: List[torch.Tensor]):
         # preprocess
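
The change above threads the target device through to the builder, so the CLIP submodel's parameters are allocated on that device from the start rather than being initialized on the CPU and copied over afterwards. A minimal sketch of the pattern (illustrative only, not repository code; TinyEncoder and its sizes are made up, and a CUDA device is assumed to be available):

    import torch.nn as nn

    class TinyEncoder(nn.Module):
        # Hypothetical stand-in for a builder like clip_xlm_roberta_vit_h_14.
        def __init__(self, dim: int = 1024, device: str = "cuda:0"):
            super().__init__()
            # The weight is allocated directly on `device`, so there is no
            # CPU-side tensor to create and no .to(device) copy afterwards.
            self.proj = nn.Linear(dim, dim, bias=False, device=device)

    # Old pattern: build on CPU, then copy every parameter to the GPU.
    cpu_then_move = TinyEncoder(device="cpu").to("cuda:0")

    # Pattern used by this commit: build on the target device from the start.
    on_device = TinyEncoder(device="cuda:0")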

diffsynth_engine/models/wan/wan_text_encoder.py

Lines changed: 22 additions & 18 deletions
@@ -38,19 +38,20 @@ def forward(self, x):
 
 
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
+    def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
         assert dim_attn % num_heads == 0
         super(T5Attention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
         self.num_heads = num_heads
         self.head_dim = dim_attn // num_heads
+        self.device = device
 
         # layers
-        self.q = nn.Linear(dim, dim_attn, bias=False)
-        self.k = nn.Linear(dim, dim_attn, bias=False)
-        self.v = nn.Linear(dim, dim_attn, bias=False)
-        self.o = nn.Linear(dim_attn, dim, bias=False)
+        self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.k = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.v = nn.Linear(dim, dim_attn, bias=False, device=device)
+        self.o = nn.Linear(dim_attn, dim, bias=False, device=device)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x, context=None, mask=None, pos_bias=None):
@@ -90,15 +91,16 @@ def forward(self, x, context=None, mask=None, pos_bias=None):
 
 
 class T5FeedForward(nn.Module):
-    def __init__(self, dim, dim_ffn, dropout=0.0):
+    def __init__(self, dim, dim_ffn, dropout=0.0, device="cuda:0"):
         super(T5FeedForward, self).__init__()
         self.dim = dim
         self.dim_ffn = dim_ffn
+        self.device = device
 
         # layers
-        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
-        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
-        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
+        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False, device=device), GELU())
+        self.fc1 = nn.Linear(dim, dim_ffn, bias=False, device=device)
+        self.fc2 = nn.Linear(dim_ffn, dim, bias=False, device=device)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
@@ -110,21 +112,22 @@ def forward(self, x):
 
 
 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0, device="cuda:0"):
         super(T5SelfAttention, self).__init__()
         self.dim = dim
         self.dim_attn = dim_attn
         self.dim_ffn = dim_ffn
         self.num_heads = num_heads
         self.num_buckets = num_buckets
         self.shared_pos = shared_pos
+        self.device = device
 
         # layers
         self.norm1 = T5LayerNorm(dim)
-        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
+        self.attn = T5Attention(dim, dim_attn, num_heads, dropout, device)
         self.norm2 = T5LayerNorm(dim)
-        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
-        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True)
+        self.ffn = T5FeedForward(dim, dim_ffn, dropout, device)
+        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device)
 
     def forward(self, x, mask=None, pos_bias=None):
         e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1))
@@ -134,15 +137,16 @@ def forward(self, x, mask=None, pos_bias=None):
 
 
 class T5RelativeEmbedding(nn.Module):
-    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
+    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128, device="cuda:0"):
         super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
         self.num_heads = num_heads
         self.bidirectional = bidirectional
         self.max_dist = max_dist
+        self.device = device
 
         # layers
-        self.embedding = nn.Embedding(num_buckets, num_heads)
+        self.embedding = nn.Embedding(num_buckets, num_heads, device=device)
 
     def forward(self, lq, lk):
         device = self.embedding.weight.device
@@ -257,12 +261,12 @@ def __init__(
         self.shared_pos = shared_pos
 
         # layers
-        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim)
-        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None
+        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim, device=device)
+        self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True, device=device) if shared_pos else None
         self.dropout = nn.Dropout(dropout)
         self.blocks = nn.ModuleList(
             [
-                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
+                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout, device)
                 for _ in range(num_layers)
             ]
         )
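
On the "skip weight init" part of the commit title: once every leaf module forwards a device argument, the same constructors can also be given device="meta", where parameters carry only shape and dtype, so the usual random initialization does no work and real weights can be assigned later from a checkpoint. Whether diffsynth_engine loads checkpoints this way is not shown in this diff; the following is a minimal sketch assuming PyTorch >= 2.1 and an available CUDA device:

    import torch
    import torch.nn as nn

    # Parameters on the "meta" device have shape/dtype but no storage,
    # so nothing is randomly initialized here.
    layer = nn.Linear(4096, 4096, bias=False, device="meta")
    assert layer.weight.is_meta

    # Later, materialize real weights (here a placeholder tensor standing in
    # for checkpoint data) directly on the target device; assign=True swaps
    # the meta parameter for the loaded tensor.
    state = {"weight": torch.empty(4096, 4096, device="cuda:0")}
    layer.load_state_dict(state, assign=True)
    print(layer.weight.device)  # cuda:0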
