# demo_sdxl.py
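"""Gradio demo (flow as implemented below): SDXL first generates an image from
the text prompt; CLIP plus a resampler extract IP-Adapter features from it;
LLaVA produces hidden states for a randomly sampled edit instruction; an
aligner cross-attends the two; finally IP-Adapter Plus (SDXL) re-generates the
image from the aligned features."""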
import torch
import json
import gradio as gr
import os
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from src.run_llava_gen import eval_model
from src.resampler import Resampler, Resampler_cross
from src.utils import get_ip_feat, get_random_instruction
from transformers import CLIPVisionModel, CLIPImageProcessor
from diffusers import StableDiffusionXLPipeline
from third_party.ip_adapter import IPAdapterPlusXL_fp16 as IPAdapterPlusXL
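# Expected checkpoint layout under ./ckpts (inferred from the paths used below):
#   llava/  image_encoder/  resampler.pth  aligner_sdxl.pth  ip-adapter-plus_sdxl_vit-h.bin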
ckpts = "./ckpts"
with open("assets/edit_instructions.json", 'r') as file:
    instructions = json.load(file)
weight_dtype = torch.float16
# try different devices for the MLLM and the diffusion pipeline when facing OOM issues
lm_device = "cuda:0"
sdxl_device = "cuda:0"
model_path = os.path.join(ckpts, "llava")
lm_tokenizer, lm_model, lm_image_processor, lm_context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    device_map=lm_device
)
lm_model = lm_model.to(lm_device, dtype=torch.bfloat16)
lm_model.eval()
sdxl_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(sdxl_model_id, torch_dtype=weight_dtype, variant="fp16", use_safetensors=True).to(sdxl_device)
ip_image_processor = CLIPImageProcessor()
ip_resampler = Resampler(dim=1280, depth=4, dim_head=64, heads=20, num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4,)
aligner = Resampler_cross(dim=2048, depth=4, dim_head=64, heads=20, num_queries=16, embedding_dim=5120, output_dim=2048, ff_mult=4,)
ip_clip = CLIPVisionModel.from_pretrained(os.path.join(ckpts, "image_encoder"))
ip_resampler.load_state_dict(torch.load(os.path.join(ckpts, "resampler.pth")))
ip_resampler.requires_grad_(False)
ip_clip.requires_grad_(False)
ip_resampler = ip_resampler.to(sdxl_device, dtype=weight_dtype)
ip_clip = ip_clip.to(sdxl_device, dtype=weight_dtype)
aligner = aligner.to(sdxl_device, dtype=torch.float32)
ip_resampler.eval()
ip_clip.eval()
aligner.eval()
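# Note: this loads the SDXL weights a second time (sdxl_pipe above handles the
# baseline generation, pipe below backs the IP-Adapter re-generation), roughly
# doubling the SDXL memory footprint on a single device.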
pipe = StableDiffusionXLPipeline.from_pretrained(sdxl_model_id, torch_dtype=torch.float16, variant="fp16", add_watermarker=False, use_safetensors=True).to(sdxl_device)
ip_model = IPAdapterPlusXL(pipe, os.path.join(ckpts, "image_encoder"), os.path.join(ckpts, "ip-adapter-plus_sdxl_vit-h.bin"), sdxl_device, num_tokens=16)
model_ckpt = os.path.join(ckpts, "aligner_sdxl.pth")
model_state_dict = torch.load(model_ckpt, map_location='cpu')
# strip the 'module.' prefix that DataParallel/DDP leaves on checkpoint keys
new_state_dict = {k[7:] if k.startswith('module.') else k: v for k, v in model_state_dict.items()}
aligner.load_state_dict(new_state_dict)
def generate_images(text_prompt, seed):
    """Generate a baseline SDXL image, then re-generate it from aligned features."""
    seed = int(seed)  # gr.Number may deliver a float; manual_seed requires an int
    sdxl_image = sdxl_pipe(text_prompt, generator=torch.Generator().manual_seed(seed), num_inference_steps=50, guidance_scale=7.5).images[0]
    ip_feat = get_ip_feat(sdxl_image, ip_image_processor, ip_clip, ip_resampler)
    random_instruction = get_random_instruction(instructions)
    prompt = random_instruction.replace("{prompt}", text_prompt)
    lm_args = type('Args', (), {
        "model_path": model_path,
        "model_base": None,
        "model_name": get_model_name_from_path(model_path),
        "query": prompt,
        "conv_mode": None,
        "image_file": [sdxl_image],
        "sep": ",",
        "temperature": 0,
        "top_p": None,
        "num_beams": 1,
        "max_new_tokens": 512
    })()
    with torch.no_grad():
        hidden_states, _ = eval_model(lm_args, lm_tokenizer, lm_model, lm_image_processor, lm_context_len, dtype=torch.bfloat16, generate=False, layer=-1)
        hidden_states = hidden_states.to(dtype=torch.float32, device=sdxl_device)
        ip_feat = ip_feat.to(dtype=torch.float32, device=sdxl_device)
        # iteratively refine the IP-Adapter features against the LLaVA hidden states
        for _ in range(3):
            ip_feat = aligner(hidden_states, ip_feat)
    aligned_ip_feat = ip_feat.to(device=sdxl_device, dtype=weight_dtype)
    generated_image = ip_model.generate_from_feat(
        feat=aligned_ip_feat,
        num_samples=1,
        num_inference_steps=50,
        scale=0.2,
        prompt=text_prompt,
        seed=seed,
        guidance_scale=7.5
    )
    return sdxl_image, generated_image[0]
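# Headless usage sketch (hypothetical, not part of the original demo): the
# function can also be called directly, bypassing the UI, e.g.
#   sdxl_img, ours_img = generate_images("A robot penguin wearing a top hat ...", 306)
#   sdxl_img.save("sdxl.png"); ours_img.save("ours.png")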
with gr.Blocks() as demo:
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", value="A robot penguin wearing a top hat and playing a vintage trumpet under a rainbow.")
        seed = gr.Number(label="Seed", value=306)
    with gr.Row():
        output_sdxl = gr.Image(label="Image Generated by SDXL", type="pil")
        output_img = gr.Image(label="Image Re-generated by IMG (Ours)", type="pil")
    generate_btn = gr.Button("Generate")
    gr.Examples(
        examples=[
            ["A robot penguin wearing a top hat and playing a vintage trumpet under a rainbow.", 306],
            ["Photograph of a wall along a city street with a watercolor mural of foxes in a jazz band.", 305],
        ],
        inputs=[prompt, seed],
        outputs=[output_sdxl, output_img],
        fn=generate_images,
    )
    generate_btn.click(fn=generate_images, inputs=[prompt, seed], outputs=[output_sdxl, output_img])
# binding port 443 typically requires elevated privileges; fall back to Gradio's
# default (7860) or another high port if launch fails with a permission error
demo.launch(server_name="0.0.0.0",
            server_port=443,
            share=False,
            inbrowser=False)