
Commit 6983164

enable text cfg, some config definitions
1 parent 0d99a39 commit 6983164

File tree

7 files changed: +162 -44 lines changed


eole/config/inference.py

Lines changed: 52 additions & 18 deletions
@@ -81,8 +81,59 @@ class DecodingConfig(Config):
     align_debug: bool = Field(default=False, description="Print best align for each word.")
 
 
+class ImageGenerationConfig(Config):
+    """
+    Let's centralize image generation related stuff here.
+    This is not a complete config, but rather a subset of options
+    that are relevant for image generation tasks.
+    Used as a mixin for InferenceConfig for now, but might be properly nested at some point.
+    """
+
+    # image generation specific stuff, might move elsewhere
+    image_generation: bool | None = Field(
+        default=False,
+        description="Generate image from text input. "
+        "This will only work if the model is trained for image generation.",
+    )
+    image_width: int | None = Field(
+        default=1024,
+        description="Width of the generated image. "
+        "This will only work if the model is trained for image generation.",
+    )
+    image_height: int | None = Field(
+        default=1024,
+        description="Height of the generated image. "
+        "This will only work if the model is trained for image generation.",
+    )
+    cfg_text_scale: float | None = Field(
+        default=1.0,
+        description="Classifier-free guidance scale for text input.",
+    )
+    cfg_image_scale: float | None = Field(
+        default=1.0,
+        description="Classifier-free guidance scale for image input.",
+    )
+    cfg_interval_min: float | None = Field(
+        default=0.0,
+        description="Minimum classifier-free guidance interval.",
+    )
+    cfg_interval_max: float | None = Field(
+        default=1.0,
+        description="Maximum classifier-free guidance interval.",
+    )
+    timestep_shift: float | None = Field(
+        default=1.0,
+        description="Shift the timestep for image generation.",
+    )
+    num_timesteps: int | None = Field(
+        default=50,
+        description="Number of timesteps for image generation.",
+    )
+
+
 # in legacy opts, decoding config is separated (probably to be used elsewhere)
-class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig):
+class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig, ImageGenerationConfig):
 
     model_config = get_config_dict()
     model_config["arbitrary_types_allowed"] = True  # to allow torch.dtype
@@ -111,23 +162,6 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
         description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
     )
 
-    # image generation specific stuff, might move elsewhere
-    image_generation: bool | None = Field(
-        default=False,
-        description="Generate image from text input. "
-        "This will only work if the model is trained for image generation.",
-    )
-    image_width: int | None = Field(
-        default=1024,
-        description="Width of the generated image. "
-        "This will only work if the model is trained for image generation.",
-    )
-    image_height: int | None = Field(
-        default=1024,
-        description="Height of the generated image. "
-        "This will only work if the model is trained for image generation.",
-    )
-
     def get_model_path(self):
         return self.model_path[0]
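As a usage note (not part of the commit): with ImageGenerationConfig mixed into InferenceConfig, the new classifier-free guidance knobs sit next to the existing image options. A minimal sketch, assuming the usual InferenceConfig entry point; the model_path value is a placeholder and any other required fields are omitted:

```python
# Hypothetical usage of the new ImageGenerationConfig fields (field names come
# from the diff above; model_path is a placeholder, not from this commit).
from eole.config.inference import InferenceConfig

config = InferenceConfig(
    model_path=["/path/to/bagel/model"],
    image_generation=True,
    image_width=1024,
    image_height=1024,
    cfg_text_scale=4.0,    # > 1.0 turns on the text CFG branch added in model.py
    cfg_image_scale=1.0,   # left at 1.0, so the image CFG branch stays inactive
    cfg_interval_min=0.0,  # apply CFG over the full timestep interval
    cfg_interval_max=1.0,
    timestep_shift=1.0,
    num_timesteps=50,      # more steps: slower sampling, usually cleaner images
)
```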

eole/decoders/transformer.py

Lines changed: 4 additions & 4 deletions
@@ -138,8 +138,8 @@ def forward(self, layer_in, **kwargs):
         else:
             norm_layer_in = self.input_layernorm(layer_in)
 
-        print("NORM_LAYER_IN:", norm_layer_in.shape, norm_layer_in.sum(), norm_layer_in)
-        print("NORM_LAYER_IN img:", norm_layer_in[:, -4098:, :].shape, norm_layer_in[:, -4098:, :].sum(), norm_layer_in[:, -4098:, :])
+        # print("NORM_LAYER_IN:", norm_layer_in.shape, norm_layer_in.sum(), norm_layer_in)
+        # print("NORM_LAYER_IN img:", norm_layer_in[:, -4098:, :].shape, norm_layer_in[:, -4098:, :].sum(), norm_layer_in[:, -4098:, :])
 
         self_attn, attns = self.self_attn(
             norm_layer_in,
@@ -418,7 +418,7 @@ def forward(self, emb, **kwargs):
             attn_mask = attn_mask[:, :, :, -self.sliding_window :]
 
         for i, layer in enumerate(self.transformer_layers):
-            print(f"\n=================\nLAYER {i}\n=================\n")
+            # print(f"\n=================\nLAYER {i}\n=================\n")
             emb, attn = layer(
                 emb,
                 enc_out=enc_out if enc_out is not None else emb,
@@ -431,7 +431,7 @@ def forward(self, emb, **kwargs):
                 ),
                 **kwargs,
             )
-            print("EMB:", emb.shape, emb.sum(), emb)
+            # print("EMB:", emb.shape, emb.sum(), emb)
            if with_align:
                attn_align = layer.get_attn_align(
                    emb,

eole/models/model.py

Lines changed: 98 additions & 18 deletions
@@ -30,7 +30,7 @@
 
 import math
 from PIL import Image
-
+from tqdm import tqdm
 
 def build_encoder(model_config, running_config=None):
     """
@@ -1257,7 +1257,7 @@ def generate_image(self, text_src, init_noise, position_ids, num_timesteps=20, t
             0, num_image_tokens, device=device
         )
 
-        for i, t in enumerate(timesteps):
+        for i, t in tqdm(enumerate(timesteps)):
             timestep = torch.tensor([t] * num_image_tokens, device=device)
             if t > min_cfg and t <= max_cfg:
                 cfg_text_scale_ = cfg_text_scale
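One aside on the progress bar: `tqdm(enumerate(timesteps))` works, but `enumerate` is a plain generator, so tqdm cannot infer a total and only shows a raw iteration count. A small illustration (the timestep list below is a stand-in, not the real schedule):

```python
from tqdm import tqdm

timesteps = [1.0 - i / 50 for i in range(50)]  # stand-in schedule for illustration

# Passing total (or wrapping the list itself) restores the percentage/ETA display:
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
    pass  # one denoising step per timestep, as in generate_image above
# equivalently: for i, t in enumerate(tqdm(timesteps)): ...
```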
@@ -1285,51 +1285,69 @@ def generate_image(self, text_src, init_noise, position_ids, num_timesteps=20, t
         # no real multi image support for now, so we just return the first one
         return output[0]
 
-    def forward_image_gen(self, text_src, x_t, timestep, text_ids, text_indices, image_indices, seqlens, image_position_ids, position_ids):
+    def forward_image_gen(
+        self,
+        text_src,
+        x_t,
+        timestep,
+        text_ids,
+        text_indices,
+        image_indices,
+        seqlens,
+        image_position_ids,
+        position_ids,
+        # cfg_text_scale=1.0,
+        cfg_text_scale=4.0,
+        cfg_img_scale=1.0,
+        cfg_renorm_type="global",
+        cfg_renorm_min=0.0,
+    ):
         """
         (Somewhat corresponds to bagel._forward_flow at high level.)
         """
-        print("TEXT_SRC:", text_src.shape, text_src)
-        print("TEXT_IDS:", text_ids)
+        # print("TEXT_SRC:", text_src.shape, text_src)
+        # print("TEXT_IDS:", text_ids)
         text_embeddings = self.tgt_emb(text_ids)
         text_prompt_emb = self.tgt_emb(text_src)
 
-        print("TEXT_EMBEDDINGS:", text_embeddings.shape, text_embeddings.sum(), text_embeddings.dtype)
+        # print("TEXT_EMBEDDINGS:", text_embeddings.shape, text_embeddings.sum(), text_embeddings.dtype)
         sequence = text_embeddings.new_zeros((sum(seqlens), self.hidden_size))
-        print("SEQUENCE:", sequence.shape, sequence.sum(), sequence.dtype)
+        # print("SEQUENCE:", sequence.shape, sequence.sum(), sequence.dtype)
         sequence[text_indices] = text_embeddings
 
-        print("IMAGE_POSITION_IDS:", image_position_ids)
+        # print("IMAGE_POSITION_IDS:", image_position_ids)
         position_embeddings = self.latent_pos_embed(image_position_ids)
-        print("POSITION_EMBEDDINGS:", position_embeddings.shape, position_embeddings.sum(), position_embeddings.dtype)
+        # print("POSITION_EMBEDDINGS:", position_embeddings.shape, position_embeddings.sum(), position_embeddings.dtype)
         timestep_embeddings = self.time_embedder(timestep)
-        print("TIMESTEP_EMBEDDINGS:", timestep_embeddings.shape, timestep_embeddings.sum(), timestep_embeddings.dtype)
-        print("X_T:", x_t.shape, x_t.sum(), x_t.dtype)
+        # print("TIMESTEP_EMBEDDINGS:", timestep_embeddings.shape, timestep_embeddings.sum(), timestep_embeddings.dtype)
+        # print("X_T:", x_t.shape, x_t.sum(), x_t.dtype)
         x_t = self.vae2llm(x_t) + timestep_embeddings + position_embeddings
         sequence[image_indices] = x_t
 
         sequence = sequence.unsqueeze(0)
 
-        print("TEXT_PROMPT_EMBED:", text_prompt_emb.shape, text_prompt_emb.sum(), text_prompt_emb.dtype)
-        print("SEQUENCE before text prompt:", sequence.shape, sequence.sum(), sequence.dtype)
+        # used for CFG
+        sequence_without_text = sequence.clone()
+
+        # print("TEXT_PROMPT_EMBED:", text_prompt_emb.shape, text_prompt_emb.sum(), text_prompt_emb.dtype)
+        # print("SEQUENCE before text prompt:", sequence.shape, sequence.sum(), sequence.dtype)
         sequence = torch.cat((text_prompt_emb, sequence), dim=1)
-        print("SEQUENCE after text prompt:", sequence.shape, sequence.sum(), sequence.dtype)
+        # print("SEQUENCE after text prompt:", sequence.shape, sequence.sum(), sequence.dtype)
 
-        print("DECODER IN:", sequence.shape, sequence.sum(), sequence)
+        # print("DECODER IN:", sequence.shape, sequence.sum(), sequence)
 
         offset_image_indices = [i + text_src.size(1) for i in image_indices]
         offset_text_indices = list(range(text_src.size(1))) + [i + text_src.size(1) for i in text_indices]
-        print("OFFSET IMAGE INDICES:", len(offset_image_indices), offset_image_indices)
-        print("OFFSET TEXT INDICES:", len(offset_text_indices), offset_text_indices)
+        # print("OFFSET IMAGE INDICES:", len(offset_image_indices), offset_image_indices)
+        # print("OFFSET TEXT INDICES:", len(offset_text_indices), offset_text_indices)
         output, _ = self.decoder(
             sequence,
             step=0,  # not sure
             enc_out=None,
             src_len=seqlens,
-            with_align=False,
             # tgt_pad_mask=None,  # TODO: handle padding mask properly
             tgt_pad_mask=torch.zeros((sequence.size(0), sequence.size(1))).to(dtype=torch.bool, device=sequence.device),  # no padding
             text_indices=offset_text_indices,
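To make the index bookkeeping above easier to follow: the decoder consumes one flat sequence in which generation-side text embeddings and VAE latent tokens are scattered by position; the text prompt is then prepended, so every index shifts right by the prompt length. A toy example with invented sizes (not values from the commit):

```python
# Toy illustration of the offset computation in forward_image_gen.
prompt_len = 3                 # text_src.size(1)
text_indices = [0, 1]          # generation-side text token positions
image_indices = [2, 3, 4, 5]   # latent image token positions

offset_image_indices = [i + prompt_len for i in image_indices]
offset_text_indices = list(range(prompt_len)) + [i + prompt_len for i in text_indices]

assert offset_image_indices == [5, 6, 7, 8]    # image tokens land after the prompt
assert offset_text_indices == [0, 1, 2, 3, 4]  # prompt slots plus shifted text slots
```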
@@ -1347,6 +1365,68 @@ def forward_image_gen(self, text_src, x_t, timestep, text_ids, text_indices, ima
         print("V_T before cfg:", v_t.shape, v_t.sum(), v_t.dtype)
 
         # TODO: additional conditions for cfg_text_scale / cfg_img_scale ?
+        if cfg_text_scale > 1.0:
+            cfg_text_output, _ = self.decoder(
+                sequence_without_text,
+                step=0,
+                enc_out=None,
+                src_len=sequence_without_text.size(1),
+                tgt_pad_mask=torch.zeros((sequence_without_text.size(0), sequence_without_text.size(1))).to(
+                    dtype=torch.bool, device=sequence_without_text.device
+                ),  # no padding
+                text_indices=text_indices,
+                image_indices=image_indices,
+                decoder_in=torch.tensor([[self.image_token_id] * (len(image_indices) + 2)], device=text_src.device),
+                image_token_id=self.image_token_id,
+                positions=torch.zeros((sequence_without_text.size(1)), device=text_src.device),
+                # TODO: find a way to disable cache update for such calls (might be an issue for more complex queries downstream)
+            )
+            print("CFG TEXT OUTPUT:", cfg_text_output.shape, cfg_text_output.sum(), cfg_text_output.dtype)
+            cfg_text_v_t = self.llm2vae(cfg_text_output)
+            cfg_text_v_t = cfg_text_v_t.squeeze(0)
+            cfg_text_v_t = cfg_text_v_t[image_indices]  # select only image tokens
+            print("CFG TEXT V_T:", cfg_text_v_t.shape, cfg_text_v_t.sum(), cfg_text_v_t.dtype)
+
+        if cfg_img_scale > 1.0:
+            cfg_img_v_t = v_t.clone()
+            # this is actually useful only for the image editing case (input=text+image, output=image),
+            # which is still to be investigated
+            pass
+
+        # cfg renorm stuff
+        if cfg_text_scale > 1.0:
+            print("CFG_TEXT_SCALE > 1.0")
+            print("CFG_RENORM_TYPE:", cfg_renorm_type)
+            if cfg_renorm_type == "text_channel":
+                v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
+                norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
+                norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
+                scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
+                v_t_text = v_t_text_ * scale
+                if cfg_img_scale > 1.0:
+                    v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
+                else:
+                    v_t = v_t_text
+            else:
+                v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
+
+                if cfg_img_scale > 1.0:
+                    v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
+                else:
+                    v_t_ = v_t_text_
+
+                # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
+                if cfg_renorm_type == "global":
+                    norm_v_t = torch.norm(v_t)
+                    norm_v_t_ = torch.norm(v_t_)
+                elif cfg_renorm_type == "channel":
+                    norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
+                    norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
+                else:
+                    raise NotImplementedError(f"{cfg_renorm_type} is not supported")
+                scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
+                v_t = v_t_ * scale
+
 
         return v_t
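The guidance math added here is the standard classifier-free guidance combination, v_guided = v_uncond + s * (v_cond - v_uncond), where v_cond is the velocity predicted with the text prompt in context and v_uncond comes from the extra decoder pass over sequence_without_text; the renorm step then rescales v_guided so its norm does not exceed that of the text-conditioned prediction. A standalone sketch of the "global" renorm branch (the function name and isolated form are illustrative; the arithmetic mirrors the diff):

```python
import torch


def cfg_text_global_renorm(v_cond, v_uncond, cfg_text_scale=4.0, cfg_renorm_min=0.0, eps=1e-8):
    """Sketch of text-only CFG with global renorm, mirroring the branch above."""
    # classifier-free guidance combination
    v_guided = v_uncond + cfg_text_scale * (v_cond - v_uncond)
    # global renorm: keep the guided velocity's norm at or below the conditional one
    scale = (torch.norm(v_cond) / (torch.norm(v_guided) + eps)).clamp(min=cfg_renorm_min, max=1.0)
    return v_guided * scale
```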

eole/modules/rope.py

Lines changed: 0 additions & 1 deletion
@@ -197,7 +197,6 @@ def forward_1d(self, maxseqlen, step=0, prefetch=1024, offset=32, positions=None
         tmax += self.model_config.rope_config.tmax_index
 
         rope = torch.outer(tmax, self.inv_freq.to(device))
-        print("ROPE freqs:", rope.shape, rope.sum(), rope)
         cos = torch.cos(rope)
         sin = torch.sin(rope)
         cos = torch.cat((cos, cos), dim=-1).to(dtype)  # Double the size by repeating `cos`

eole/predict/inference.py

Lines changed: 4 additions & 1 deletion
@@ -92,6 +92,7 @@ def __init__(
         image_generation=False,
         image_width=1024,
         image_height=1024,
+        num_timesteps=20,
     ):
         self.model = model
         self.vocabs = vocabs
@@ -174,6 +175,7 @@ def __init__(
         self.image_generation = image_generation
         self.image_width = image_width
         self.image_height = image_height
+        self.num_timesteps = num_timesteps
 
     @classmethod
     def from_config(
@@ -253,6 +255,7 @@ def from_config(
             image_generation=config.image_generation,
             image_width=config.image_width,
             image_height=config.image_height,
+            num_timesteps=config.num_timesteps,
         )
 
     def _log(self, msg):
@@ -663,7 +666,7 @@ def _decode_and_generate(
                 decoder_in,
                 init_noise,
                 position_ids,
-                num_timesteps=50,
+                num_timesteps=self.num_timesteps,
             )
             image = self.model.decode_image(latent, self.image_height, self.image_width)
             image.save("generated_image.png")

recipes/bagel/generated_image.png

Binary file changed (-244 KB)

recipes/bagel/test_bagel.py

Lines changed: 4 additions & 2 deletions
@@ -37,6 +37,9 @@
     image_generation=True,
     image_width=1024,
     image_height=1024,
+    # num_timesteps=10,
+    # num_timesteps=30,
+    num_timesteps=50,
     # self_attn_backend="flash",  # not properly supported (mixed masking)
 )

@@ -49,8 +52,7 @@
 print(engine.predictor.model)
 engine.predictor.model.count_parameters()
 
-# prompt = "A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
-prompt = "A breathtaking photorealistic landscape of a windswept coastal cliff at golden hour. The scene features jagged rocks covered in moss, waves crashing below with mist rising, and seabirds flying overhead. The lighting is warm and natural, casting long shadows and reflecting on wet surfaces. The level of detail is ultra high, with textures of stone, water, and clouds rendered realistically, evoking a feeling of awe and solitude."
+prompt = "A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
 
 # test_input = [{
 #     "text": f"<|im_start|>{prompt}<|im_end|><|im_start|>"
