Skip to content

Commit ccfc030

Browse files
ybitterman authored and yoavhacohen committed
Inference: Integrate fp8 kernels
1 parent bdc8f01 commit ccfc030

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

ltx_video/inference.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,28 @@ def seed_everething(seed: int):
182182
torch.mps.manual_seed(seed)
183183

184184

185+
def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    """Load a Transformer3DModel from a checkpoint at the requested precision.

    Args:
        ckpt_path: Path to the pretrained checkpoint.
        precision: "float8_e4m3fn" loads the weights as fp8 and patches the
            model with Q8 kernels; "bfloat16" casts the loaded model to
            bfloat16; any other value loads the checkpoint at its native dtype.

    Returns:
        The loaded transformer (Q8-kernel-patched when precision is
        "float8_e4m3fn").

    Raises:
        ValueError: If precision is "float8_e4m3fn" but the optional
            q8_kernels package is not installed.
    """
    if precision == "float8_e4m3fn":
        # Keep the try narrow: only the optional-dependency import should be
        # treated as "Q8 kernels missing". An ImportError raised deeper inside
        # from_pretrained or the patch call must not be mislabeled.
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )
        except ImportError as err:
            # Chain the original error so the missing-module details survive.
            raise ValueError(
                "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            ) from err

        transformer = Transformer3DModel.from_pretrained(
            ckpt_path, dtype=torch.float8_e4m3fn
        )
        patch_transformer_for_q8_kernels(transformer)
        return transformer
    elif precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    else:
        return Transformer3DModel.from_pretrained(ckpt_path)
205+
206+
185207
def create_ltx_video_pipeline(
186208
ckpt_path: str,
187209
precision: str,
@@ -204,7 +226,7 @@ def create_ltx_video_pipeline(
204226
allowed_inference_steps = configs.get("allowed_inference_steps", None)
205227

206228
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
207-
transformer = Transformer3DModel.from_pretrained(ckpt_path)
229+
transformer = create_transformer(ckpt_path, precision)
208230

209231
# Use constructor if sampler is specified, otherwise use from_pretrained
210232
if sampler == "from_checkpoint" or not sampler:
@@ -247,8 +269,6 @@ def create_ltx_video_pipeline(
247269
prompt_enhancer_llm_tokenizer = None
248270

249271
vae = vae.to(torch.bfloat16)
250-
if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
251-
transformer = transformer.to(torch.bfloat16)
252272
text_encoder = text_encoder.to(torch.bfloat16)
253273

254274
# Use submodels for the pipeline

0 commit comments

Comments
 (0)