@@ -182,6 +182,28 @@ def seed_everething(seed: int):
182182 torch .mps .manual_seed (seed )
183183
184184
185+ def create_transformer (ckpt_path : str , precision : str ) -> Transformer3DModel :
186+ if precision == "float8_e4m3fn" :
187+ try :
188+ from q8_kernels .integration .patch_transformer import (
189+ patch_diffusers_transformer as patch_transformer_for_q8_kernels ,
190+ )
191+
192+ transformer = Transformer3DModel .from_pretrained (
193+ ckpt_path , dtype = torch .float8_e4m3fn
194+ )
195+ patch_transformer_for_q8_kernels (transformer )
196+ return transformer
197+ except ImportError :
198+ raise ValueError (
199+ "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
200+ )
201+ elif precision == "bfloat16" :
202+ return Transformer3DModel .from_pretrained (ckpt_path ).to (torch .bfloat16 )
203+ else :
204+ return Transformer3DModel .from_pretrained (ckpt_path )
205+
206+
185207def create_ltx_video_pipeline (
186208 ckpt_path : str ,
187209 precision : str ,
@@ -204,7 +226,7 @@ def create_ltx_video_pipeline(
204226 allowed_inference_steps = configs .get ("allowed_inference_steps" , None )
205227
206228 vae = CausalVideoAutoencoder .from_pretrained (ckpt_path )
207- transformer = Transformer3DModel . from_pretrained (ckpt_path )
229+ transformer = create_transformer (ckpt_path , precision )
208230
209231 # Use constructor if sampler is specified, otherwise use from_pretrained
210232 if sampler == "from_checkpoint" or not sampler :
@@ -247,8 +269,6 @@ def create_ltx_video_pipeline(
247269 prompt_enhancer_llm_tokenizer = None
248270
249271 vae = vae .to (torch .bfloat16 )
250- if precision == "bfloat16" and transformer .dtype != torch .bfloat16 :
251- transformer = transformer .to (torch .bfloat16 )
252272 text_encoder = text_encoder .to (torch .bfloat16 )
253273
254274 # Use submodels for the pipeline
0 commit comments