Skip to content

Commit 25db678

Browse files
ai-edge-botcopybara-github
authored andcommitted
Convert stable diffusion model as inference only
PiperOrigin-RevId: 719393588
1 parent d09e97a commit 25db678

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def convert_stable_diffusion_to_tflite(
8585
clip.TENSOR_NAMES,
8686
)
8787
loader.load(clip_model, strict=False)
88+
clip_model.eval()
8889

8990
diffusion_model = diffusion.Diffusion(
9091
diffusion.get_model_config(batch_size=2, device_type=_DEVICE_TYPE.value)
@@ -93,6 +94,7 @@ def convert_stable_diffusion_to_tflite(
9394
diffusion_ckpt_path, diffusion.TENSOR_NAMES
9495
)
9596
diffusion_loader.load(diffusion_model, strict=False)
97+
diffusion_model.eval()
9698

9799
decoder_model = decoder.Decoder(
98100
decoder.get_model_config(device_type=_DEVICE_TYPE.value)
@@ -101,6 +103,7 @@ def convert_stable_diffusion_to_tflite(
101103
decoder_ckpt_path, decoder.TENSOR_NAMES
102104
)
103105
decoder_loader.load(decoder_model, strict=False)
106+
decoder_model.eval()
104107

105108
# TODO(yichunk): enable image encoder conversion
106109
# if encoder_ckpt_path is not None:

0 commit comments

Comments
 (0)