
Commit 265d7dc: Decoupled model definition and backend lowering

Parent: 0f09020

File tree

3 files changed: +223, −126 lines changed

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import LCMModelLoader, TextEncoderWrapper, UNetWrapper, VAEDecoder

__all__ = ["LCMModelLoader", "TextEncoderWrapper", "UNetWrapper", "VAEDecoder"]
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
# Copyright (c) Intel Corporation
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file found in the
# LICENSE file in the root directory of this source tree.

"""
Stable Diffusion / LCM model definitions.

This module provides reusable model wrappers that can be used with any backend
(OpenVINO, XNNPACK, etc.) for exporting Latent Consistency Models.
"""

import logging
from typing import Any, Optional

import torch

try:
    from diffusers import DiffusionPipeline
except ImportError:
    raise ImportError(
        "Please install diffusers and transformers: pip install diffusers transformers"
    )

logger = logging.getLogger(__name__)

class TextEncoderWrapper(torch.nn.Module):
    """Wrapper for CLIP text encoder that extracts last_hidden_state"""

    def __init__(self, text_encoder):
        super().__init__()
        self.text_encoder = text_encoder

    def forward(self, input_ids):
        # Call text encoder and extract last_hidden_state
        output = self.text_encoder(input_ids, return_dict=True)
        return output.last_hidden_state

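The point of the wrapper is that the exported graph returns a plain tensor rather than the ModelOutput object the raw CLIP encoder produces, which keeps the export signature simple. A minimal sketch of exporting it with torch.export, assuming the encoder was loaded via LCMModelLoader below:

import torch

# `text_encoder` as produced by LCMModelLoader.load_models() (eval mode, target dtype).
wrapper = TextEncoderWrapper(text_encoder)
input_ids = torch.ones(1, 77, dtype=torch.long)  # CLIP's fixed 77-token context
exported = torch.export.export(wrapper, (input_ids,))
print(type(exported))  # torch.export.ExportedProgram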
class UNetWrapper(torch.nn.Module):
    """Wrapper for UNet that extracts sample tensor from output"""

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, latents, timestep, encoder_hidden_states):
        # Call UNet and extract sample from the output
        output = self.unet(latents, timestep, encoder_hidden_states, return_dict=True)
        return output.sample

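As a quick sanity check of the wrapper's call signature, the dummy shapes below mirror get_dummy_inputs() further down; 64x64 latents correspond to 512x512 images for this model family:

import torch

# `unet` as produced by LCMModelLoader.load_models().
wrapper = UNetWrapper(unet)
latents = torch.randn(1, 4, 64, 64, dtype=unet.dtype)
timestep = torch.tensor([981])
text_embeds = torch.randn(1, 77, unet.config.cross_attention_dim, dtype=unet.dtype)
with torch.no_grad():
    out = wrapper(latents, timestep, text_embeds)
assert out.shape == latents.shape  # the UNet predicts in latent space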
class VAEDecoder(torch.nn.Module):
    """Wrapper for VAE decoder with scaling and normalization"""

    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        # Scale latents
        latents = latents / self.vae.config.scaling_factor
        # Decode
        image = self.vae.decode(latents).sample
        # Scale to [0, 1]
        image = (image / 2 + 0.5).clamp(0, 1)
        return image

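Because the decoder already rescales its output to [0, 1], converting the result to an image file takes only a couple of lines; a sketch assuming a single-image (1, 3, H, W) float tensor:

import numpy as np
from PIL import Image

# `image` is the (1, 3, H, W) tensor in [0, 1] returned by VAEDecoder.
arr = (image[0].permute(1, 2, 0).float().cpu().numpy() * 255).round().astype(np.uint8)
Image.fromarray(arr).save("output.png")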
class LCMModelLoader:
    """
    Backend-agnostic loader for Latent Consistency Model components.

    This class handles loading the LCM pipeline from HuggingFace and extracting
    individual components (text_encoder, unet, vae) as PyTorch modules ready
    for export to any backend.
    """

    def __init__(
        self,
        model_id: str = "SimianLuo/LCM_Dreamshaper_v7",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Initialize the LCM model loader.

        Args:
            model_id: HuggingFace model ID for the LCM model
            dtype: Target dtype for the models (fp16 or fp32)
        """
        self.model_id = model_id
        self.dtype = dtype
        self.pipeline: Optional[DiffusionPipeline] = None
        self.text_encoder: Any = None
        self.unet: Any = None
        self.vae: Any = None
        self.tokenizer: Any = None

    def load_models(self) -> bool:
        """
        Load the LCM pipeline and extract components.

        Returns:
            True if successful, False otherwise
        """
        try:
            logger.info(f"Loading LCM pipeline: {self.model_id} (dtype: {self.dtype})")
            self.pipeline = DiffusionPipeline.from_pretrained(
                self.model_id, use_safetensors=True
            )

            # Extract individual components and convert to desired dtype
            self.text_encoder = self.pipeline.text_encoder.to(dtype=self.dtype)
            self.unet = self.pipeline.unet.to(dtype=self.dtype)
            self.vae = self.pipeline.vae.to(dtype=self.dtype)
            self.tokenizer = self.pipeline.tokenizer

            # Set models to evaluation mode
            self.text_encoder.eval()
            self.unet.eval()
            self.vae.eval()

            logger.info("Successfully loaded all LCM model components")
            return True

        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            import traceback

            traceback.print_exc()
            return False

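Typical usage, assuming network access to the HuggingFace Hub on first run; fp32 is shown here as a hedge for backends without fp16 support:

loader = LCMModelLoader(dtype=torch.float32)
if not loader.load_models():
    raise SystemExit("Failed to load LCM components; see traceback above.")

# The tokenizer is kept around so callers can build real text-encoder inputs.
ids = loader.tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",
    max_length=77,
    return_tensors="pt",
).input_ids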
    def get_text_encoder_wrapper(self) -> TextEncoderWrapper:
        """Get wrapped text encoder ready for export"""
        if self.text_encoder is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return TextEncoderWrapper(self.text_encoder)

    def get_unet_wrapper(self) -> UNetWrapper:
        """Get wrapped UNet ready for export"""
        if self.unet is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return UNetWrapper(self.unet)

    def get_vae_decoder(self) -> VAEDecoder:
        """Get wrapped VAE decoder ready for export"""
        if self.vae is None:
            raise ValueError("Models not loaded. Call load_models() first.")
        return VAEDecoder(self.vae)

    def get_dummy_inputs(self):
        """
        Get dummy inputs for each model component.

        Returns:
            Dictionary with dummy inputs for text_encoder, unet, and vae_decoder
        """
        if self.unet is None:
            raise ValueError("Models not loaded. Call load_models() first.")

        # Text encoder dummy input
        text_encoder_input = torch.ones(1, 77, dtype=torch.long)

        # UNet dummy inputs
        batch_size = 1
        latent_channels = 4
        latent_height = 64
        latent_width = 64
        text_embed_dim = self.unet.config.cross_attention_dim
        text_seq_len = 77

        unet_inputs = (
            torch.randn(
                batch_size, latent_channels, latent_height, latent_width, dtype=self.dtype
            ),
            torch.tensor([981]),  # Fixed example timestep
            torch.randn(batch_size, text_seq_len, text_embed_dim, dtype=self.dtype),
        )

        # VAE decoder dummy input
        vae_input = torch.randn(1, 4, 64, 64, dtype=self.dtype)

        return {
            "text_encoder": (text_encoder_input,),
            "unet": unet_inputs,
            "vae_decoder": (vae_input,),
        }
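Taken together, this is the decoupling the commit title refers to: the model definition produces wrapped eager modules plus matching dummy inputs, and any backend can lower the resulting exported programs. A sketch of the shared front half, with the backend-specific step deliberately left out because it is not in this diff:

import torch

loader = LCMModelLoader()
assert loader.load_models(), "failed to fetch LCM components"
dummy = loader.get_dummy_inputs()

# One ExportedProgram per component; backend lowering happens downstream.
programs = {
    "text_encoder": torch.export.export(
        loader.get_text_encoder_wrapper(), dummy["text_encoder"]
    ),
    "unet": torch.export.export(loader.get_unet_wrapper(), dummy["unet"]),
    "vae_decoder": torch.export.export(loader.get_vae_decoder(), dummy["vae_decoder"]),
}
# e.g. ExecuTorch's to_edge_transform_and_lower(...) or an OpenVINO partitioner
# would consume `programs` here.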
