-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathgenerate_interactive.py
More file actions
109 lines (90 loc) · 3.3 KB
/
generate_interactive.py
File metadata and controls
109 lines (90 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def print_zero(group, *args, **kwargs):
    """Print only when the calling process has rank 0 in ``group``.

    All positional and keyword arguments are forwarded to ``print``. Output
    is flushed unless the caller explicitly passes ``flush``.
    """
    if group.rank() != 0:
        return
    # Flush by default; an explicit flush= supplied by the caller wins.
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)
def quantization_predicate(name, m):
    """Return whether module ``m`` should be quantized.

    A module qualifies when it has a ``to_quantized`` attribute and the
    second dimension of ``m.weight`` is a multiple of 512. ``name`` is
    accepted for the predicate signature but is not inspected.
    """
    if not hasattr(m, "to_quantized"):
        return False
    second_dim = m.weight.shape[1]
    return second_dim % 512 == 0
def to_latent_size(image_size):
    """Convert a pixel ``(height, width)`` into the matching latent size.

    Each dimension is rounded up to the next multiple of 16. When rounding
    changes the requested size a warning is printed. The result is the
    rounded size divided by 8.

    Args:
        image_size: A two-element sequence ``(height, width)`` in pixels.
            Tuples and lists are both accepted.

    Returns:
        A tuple ``(height // 8, width // 8)`` computed from the rounded size.

    Raises:
        ValueError: If ``image_size`` does not unpack into exactly two items.
    """
    requested_h, requested_w = image_size
    h = ((requested_h + 15) // 16) * 16
    w = ((requested_w + 15) // 16) * 16

    # Compare against the unpacked values, not ``image_size`` itself: a tuple
    # never equals a list, so a list input would otherwise always warn.
    if (h, w) != (requested_h, requested_w):
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )
    return (h // 8, w // 8)
if __name__ == "__main__":
    # Interactive FLUX session: load the pipeline once, then repeatedly read
    # a line from stdin. A line is either a command (q / h / s HxW / n S) or
    # a text prompt, for which an image is generated and saved to --output.
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--output", default="out.png")
    args = parser.parse_args()

    flux = FluxPipeline("flux-" + args.model, t5_padding=True)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    # Shard the flow model only when more than one process participates.
    group = mx.distributed.init()
    if group.size() > 1:
        flux.flow.shard(group)

    print_zero(group, "Loading models")
    flux.ensure_models_are_loaded()

    def print_help():
        """Print the list of interactive commands (on rank 0 only)."""
        print_zero(group, "The command list:")
        print_zero(group, "- 'q' to exit")
        print_zero(group, "- 's HxW' to change the size of the image")
        print_zero(group, "- 'n S' to change the number of steps")
        print_zero(group, "- 'h' to print this help")

    print_zero(group, "FLUX interactive session")
    print_help()

    seed = 0
    size = (512, 512)
    latent_size = to_latent_size(size)
    steps = 50 if args.model == "dev" else 4

    while True:
        # Only rank 0 shows a visible prompt string.
        try:
            prompt = input(">> " if group.rank() == 0 else "")
        except EOFError:
            # End of input (e.g. Ctrl-D or a closed pipe) ends the session
            # cleanly instead of surfacing a traceback.
            break

        if prompt == "q":
            break

        if prompt == "h":
            print_help()
            continue

        if prompt.startswith("s "):
            # A malformed size must not kill the session after the models
            # have been loaded; report it and keep the previous size.
            try:
                height, width = (int(xi) for xi in prompt[2:].split("x"))
            except ValueError:
                height = width = 0
            if height <= 0 or width <= 0:
                print_zero(
                    group,
                    "Invalid size: expected 's HxW' with positive integers, "
                    "e.g. 's 512x512'",
                )
                continue
            size = (height, width)
            print_zero(group, "Setting the size to", size)
            latent_size = to_latent_size(size)
            continue

        if prompt.startswith("n "):
            try:
                new_steps = int(prompt[2:])
            except ValueError:
                new_steps = 0
            # At least one step is required: the denoising loop below is the
            # only place that binds ``xt`` before it is decoded.
            if new_steps < 1:
                print_zero(
                    group,
                    "Invalid number of steps: expected 'n S' with S a "
                    "positive integer",
                )
                continue
            steps = new_steps
            print_zero(group, "Setting the steps to", steps)
            continue

        # Anything else is treated as a text prompt. The seed is bumped so
        # each generation in the session uses a different seed.
        seed += 1
        latents = flux.generate_latents(
            prompt,
            n_images=1,
            num_steps=steps,
            latent_size=latent_size,
            guidance=4.0,
            seed=seed,
        )

        # The first item of the generator is evaluated separately, before
        # the progress bar starts.
        print_zero(group, "Processing prompt")
        mx.eval(next(latents))

        print_zero(group, "Generating latents")
        for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
            mx.eval(xt)

        print_zero(group, "Generating image")
        decoded = flux.decode(xt, latent_size)
        decoded = (decoded * 255).astype(mx.uint8)
        mx.eval(decoded)
        # NOTE(review): the save is not guarded by a rank check, so every
        # process writes args.output — confirm this is intended when
        # running with more than one process.
        im = Image.fromarray(np.array(decoded[0]))
        im.save(args.output)
        print_zero(group, "Saved at", args.output, end="\n\n")