Skip to content

Commit 91a28f6

Browse files
authored
[Bug Fix] fix the all-zero weight initialization bug (#955)
1 parent 0904366 commit 91a28f6

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

examples/hunyuanvideo-i2v/hyvideo/modules/embed_layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def __init__(
6060
# nn.init.zeros_(self.proj.bias)
6161
w = self.proj.weight
6262
w_flatted = w.reshape(w.shape[0], -1)
63-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
63+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
6464

6565
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
6666

examples/opensora_hpcai/opensora/models/stdit/stdit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def _basic_init(module):
449449
w_flatted = w.reshape(w.shape[0], -1)
450450
# FIXME: incompatible in optim parallel mode
451451
# FIXME: impl in torch can be incorrect. can be reshape order mismatch
452-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
452+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
453453

454454
# Initialize timestep embedding MLP:
455455
normal_(self.t_embedder.mlp[0].weight, std=0.02)

examples/opensora_hpcai/opensora/models/stdit/stdit2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def _basic_init(module):
538538
w_flatted = w.reshape(w.shape[0], -1)
539539
# FIXME: incompatible in optim parallel mode
540540
# FIXME: impl in torch can be incorrect. can be reshape order mismatch
541-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
541+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
542542

543543
# Initialize timestep embedding MLP:
544544
normal_(self.t_embedder.mlp[0].weight, std=0.02)

examples/pixart_sigma/pixart/modules/pixart/pixart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _basic_init(module):
183183
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
184184
w = self.x_embedder.proj.weight.data
185185
w_flatted = w.view(w.shape[0], -1)
186-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
186+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
187187
constant_(self.x_embedder.proj.bias, 0)
188188

189189
# Initialize timestep embedding MLP:
@@ -358,7 +358,7 @@ def _basic_init(module):
358358
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
359359
w = self.x_embedder.proj.weight.data
360360
w_flatted = w.view(w.shape[0], -1)
361-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
361+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
362362
constant_(self.x_embedder.proj.bias, 0)
363363

364364
# Initialize timestep embedding MLP:

mindone/models/dit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def _basic_init(module):
531531
w = self.x_embedder.proj.weight
532532
# xavier_uniform_(w.view(w.shape[0], -1))
533533
w_flatted = w.view(w.shape[0], -1)
534-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
534+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
535535
constant_(self.x_embedder.proj.bias, 0)
536536

537537
# Initialize label embedding table:

mindone/models/latte.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _basic_init(module):
123123
w = self.x_embedder.proj.weight
124124
# xavier_uniform_(w.view(w.shape[0], -1))
125125
w_flatted = w.view(w.shape[0], -1)
126-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
126+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
127127
constant_(self.x_embedder.proj.bias, 0)
128128

129129
# Initialize label embedding table:

mindone/models/mmdit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def _basic_init(module):
350350
w = self.x_embedder.proj.weight
351351
# xavier_uniform_(w.view(w.shape[0], -1))
352352
w_flatted = w.view(w.shape[0], -1)
353-
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).reshape(w.shape))
353+
w.set_data(initializer(XavierUniform(), w_flatted.shape, w_flatted.dtype).init_data().reshape(w.shape))
354354
constant_(self.x_embedder.proj.bias, 0)
355355

356356
# Initialize timestep embedding MLP:

0 commit comments

Comments (0)