Skip to content

Commit 2f96332

Browse files
haozha111copybara-github
authored andcommitted
improve conversion script for SD a bit, using some pre-defined checkpoint paths.
PiperOrigin-RevId: 729315009
1 parent 337bdee commit 2f96332

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,41 +29,46 @@
2929

3030
_CLIP_CKPT = flags.DEFINE_string(
3131
'clip_ckpt',
32-
None,
32+
os.path.join(
33+
pathlib.Path.home(),
34+
'Downloads/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors',
35+
),
3336
help='Path to source CLIP model checkpoint',
34-
required=True,
3537
)
3638

3739
_DIFFUSION_CKPT = flags.DEFINE_string(
3840
'diffusion_ckpt',
39-
None,
41+
os.path.join(
42+
pathlib.Path.home(),
43+
'Downloads/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors',
44+
),
4045
help='Path to source diffusion model checkpoint',
41-
required=True,
4246
)
4347

4448
_DECODER_CKPT = flags.DEFINE_string(
4549
'decoder_ckpt',
46-
None,
50+
os.path.join(
51+
pathlib.Path.home(),
52+
'Downloads/stable-diffusion-v1-5/v1-5-pruned-emaonly.safetensors',
53+
),
4754
help='Path to source image decoder model checkpoint',
48-
required=True,
4955
)
5056

5157
_OUTPUT_DIR = flags.DEFINE_string(
5258
'output_dir',
53-
None,
59+
'/tmp/sd_tflite',
5460
help='Path to the converted TF Lite directory.',
55-
required=True,
5661
)
5762

5863
_QUANTIZE = flags.DEFINE_bool(
5964
'quantize',
6065
help='Whether to quantize the model during conversion.',
61-
default=True,
66+
default=False,
6267
)
6368

6469
_DEVICE_TYPE = flags.DEFINE_string(
6570
'device_type',
66-
'cpu',
71+
'gpu',
6772
help='The device type of the model. Currently supported: cpu, gpu.',
6873
)
6974

0 commit comments

Comments
 (0)