@@ -417,7 +417,7 @@ def __call__(
         cfg_scale_ = cfg_scale if isinstance(cfg_scale, float) else cfg_scale[0]
 
         timestep = timestep * mask[:, :, :, ::2, ::2].flatten()  # seq_len
-        timestep = timestep.to(dtype=self.config.model_dtype, device=self.device)
+        timestep = timestep.to(dtype=self.dtype, device=self.device)
         # Classifier-free guidance
         noise_pred = self.predict_noise_with_cfg(
             model=model,
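Why the cast target changed, in isolation: config.model_dtype is the weight storage dtype, which for fp8 checkpoints is torch.float8_e4m3fn, a format with almost no elementwise-op coverage in PyTorch; the timestep is an activation and should stay in the compute dtype (self.dtype, bfloat16 once fp8 autocast is enabled). A minimal standalone sketch of that distinction, assuming a PyTorch build new enough to expose the fp8 dtype; the tensor values and variable names below are illustrative, not the pipeline's code:

import torch

storage_dtype = torch.float8_e4m3fn   # what config.model_dtype holds for fp8 checkpoints
compute_dtype = torch.bfloat16        # what pipe.dtype holds once fp8 autocast is enabled

timestep = torch.full((4,), 999.0)    # illustrative timestep values

t_bf16 = timestep.to(dtype=compute_dtype)   # safe: bf16 supports the embedding math
t_fp8 = timestep.to(dtype=storage_dtype)    # legal cast, but ops such as torch.sin(t_fp8)
                                            # typically raise, since elementwise kernels are
                                            # not implemented for float8 tensors
print(torch.sin(t_bf16 * 0.5))              # the downstream embedding path needs the bf16 cast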
@@ -574,6 +574,18 @@ def from_pretrained(cls, model_path_or_config: WanPipelineConfig) -> "WanVideoPi
         if config.offload_mode is not None:
             pipe.enable_cpu_offload(config.offload_mode)
 
+        if config.model_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["dit"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
+        if config.t5_dtype == torch.float8_e4m3fn:
+            pipe.dtype = torch.bfloat16  # compute dtype
+            pipe.enable_fp8_autocast(
+                model_names=["text_encoder"], compute_dtype=pipe.dtype, use_fp8_linear=config.use_fp8_linear
+            )
+
         if config.parallelism > 1:
             return ParallelWrapper(
                 pipe,
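For context, a hedged usage sketch of the new fp8 path. The field names (model_dtype, t5_dtype, use_fp8_linear) come from the diff above; constructing WanPipelineConfig with keyword arguments, the model_path field, and the import location are assumptions for illustration only:

import torch
# Import path depends on the repository layout; bring in WanVideoPipeline and
# WanPipelineConfig from wherever the package exposes them.

config = WanPipelineConfig(
    model_path="path/to/wan_checkpoint",   # hypothetical field, placeholder path
    model_dtype=torch.float8_e4m3fn,       # store DiT weights in fp8
    t5_dtype=torch.float8_e4m3fn,          # store the T5 text encoder in fp8
    use_fp8_linear=True,                   # route linear layers through fp8 matmuls where supported
)

pipe = WanVideoPipeline.from_pretrained(config)
# After loading, pipe.dtype is torch.bfloat16 (the compute dtype) and fp8
# autocast is enabled for both the "dit" and "text_encoder" models.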