Skip to content

Commit 6df5641

Browse files
authored
1 parent 0d38c46 commit 6df5641

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

examples/opensora_pku/opensora/models/causalvideovae/model/vae/modeling_wfvae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def sample(self, mean, logvar):
605605
# sample z from latent distribution
606606
logvar = mint.clamp(logvar, -30.0, 20.0)
607607
std = self.exp(0.5 * logvar)
608-
z = mean + std * ops.stop_gradient(self.stdnormal(size=mean.shape))
608+
z = mean + std * ops.stop_gradient(self.stdnormal(mean.shape))
609609

610610
return z
611611

examples/opensora_pku/opensora/train/commons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def parse_train_args(parser):
141141
parser.add_argument("--dataloader_num_workers", default=12, type=int, help="num workers for dataloder")
142142
parser.add_argument("--max_rowsize", default=32, type=int, help="max rowsize for data loading")
143143
parser.add_argument(
144-
"dataset_iterator_no_copy",
144+
"--dataset_iterator_no_copy",
145145
default=True,
146146
type=str2bool,
147147
help="dataset iterator optimization strategy. Whether dataset iterator creates a Tensor without copy.",

examples/opensora_pku/opensora/train/train_causalvae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def main(args):
191191
device_num=device_num,
192192
rank_id=rank_id,
193193
ds_name="video",
194-
dataset_iterator_do_copy=args.dataset_iterator_no_copy,
194+
dataset_iterator_no_copy=args.dataset_iterator_no_copy,
195195
)
196196
dataset_size = train_loader.get_dataset_size()
197197

examples/opensora_pku/opensora/train/train_t2v_diffusers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def main(args):
323323
collate_fn=collate_fn,
324324
sampler=sampler,
325325
column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"],
326-
dataset_iterator_do_copy=args.dataset_iterator_no_copy,
326+
dataset_iterator_no_copy=args.dataset_iterator_no_copy,
327327
)
328328
dataloader_size = dataloader.get_dataset_size()
329329
assert (
@@ -358,7 +358,7 @@ def main(args):
358358
collate_fn=collate_fn,
359359
sampler=sampler,
360360
column_names=["pixel_values", "attention_mask", "text_embed", "encoder_attention_mask"],
361-
dataset_iterator_do_copy=args.dataset_iterator_no_copy,
361+
dataset_iterator_no_copy=args.dataset_iterator_no_copy,
362362
)
363363
val_dataloader_size = val_dataloader.get_dataset_size()
364364
assert (

0 commit comments

Comments
 (0)