Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

src
*.egg-info
__pycache__
*/**/__pycache__
outputs
train.bat
logs
gen.bat
gen_ref.bat
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This is an implementation of Google's [Dreambooth](https://arxiv.org/abs/2208.12242) with [Stable Diffusion](https://github.com/CompVis/stable-diffusion). The original Dreambooth is based on the [Imagen](https://imagen.research.google/) text-to-image model. However, neither the model nor the pre-trained weights of Imagen are available. To enable people to fine-tune a text-to-image model with a few examples, I implemented the idea of Dreambooth on Stable Diffusion.

This code repository is based on that of [Textual Inversion](https://github.com/rinongal/textual_inversion). Note that Textual Inversion only optimizes word ebedding, while dreambooth fine-tunes the whole diffusion model.
This code repository is based on that of [Textual Inversion](https://github.com/rinongal/textual_inversion). Note that Textual Inversion only optimizes word embedding, while dreambooth fine-tunes the whole diffusion model.

The implementation makes minimal changes to the official codebase of Textual Inversion. In fact, due to laziness, some components in Textual Inversion, such as the embedding manager, are not deleted, although they will never be used here.
## Update
Expand Down
5 changes: 3 additions & 2 deletions configs/stable-diffusion/v1-finetune_unfrozen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
num_workers: 1
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
Expand Down Expand Up @@ -111,10 +111,11 @@ lightning:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
batch_frequency: 200
max_images: 8
increase_log_steps: False

trainer:
benchmark: True
max_steps: 800
# precision: 'bf16'
Binary file removed evaluation/__pycache__/clip_eval.cpython-36.pyc
Binary file not shown.
Binary file removed evaluation/__pycache__/clip_eval.cpython-38.pyc
Binary file not shown.
Binary file removed ldm/__pycache__/util.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/__pycache__/util.cpython-38.pyc
Binary file not shown.
Binary file removed ldm/data/__pycache__/__init__.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/data/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file removed ldm/data/__pycache__/base.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/data/__pycache__/base.cpython-38.pyc
Binary file not shown.
Binary file removed ldm/data/__pycache__/personalized.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/data/__pycache__/personalized.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed ldm/models/__pycache__/autoencoder.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/models/__pycache__/autoencoder.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed ldm/models/diffusion/__pycache__/ddim.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file removed ldm/models/diffusion/__pycache__/ddpm.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file removed ldm/models/diffusion/__pycache__/plms.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/models/diffusion/__pycache__/plms.cpython-38.pyc
Binary file not shown.
Binary file removed ldm/modules/__pycache__/attention.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/modules/__pycache__/attention.cpython-38.pyc
Binary file not shown.
Binary file removed ldm/modules/__pycache__/ema.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/modules/__pycache__/ema.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed ldm/modules/__pycache__/x_transformer.cpython-36.pyc
Binary file not shown.
Binary file removed ldm/modules/__pycache__/x_transformer.cpython-38.pyc
Binary file not shown.
64 changes: 47 additions & 17 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from inspect import isfunction
import math
import torch
import torch, gc
import torch.nn.functional as F
from torch import nn, einsum
from torch import nn, einsum, autocast
from einops import rearrange, repeat

from ldm.modules.diffusionmodules.util import checkpoint
Expand Down Expand Up @@ -170,28 +170,58 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
q_in = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
k_in = self.to_k(context)
v_in = self.to_v(context)
del context, x

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2

del q, k, v

r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1

return self.to_out(r2)

class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
148 changes: 111 additions & 37 deletions ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
from re import T
import torch, gc
import torch.nn as nn
import numpy as np
from einops import rearrange
Expand Down Expand Up @@ -32,7 +33,10 @@ def get_timestep_embedding(timesteps, embedding_dim):

def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
t = torch.sigmoid(x)
x *= t
del t
return x


def Normalize(in_channels, num_groups=32):
Expand Down Expand Up @@ -119,26 +123,38 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
padding=0)

def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h1 = x
h2 = self.norm1(h1)
del h1

h3 = nonlinearity(h2)
del h2

h4 = self.conv1(h3)
del h3

if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]

h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
h5 = self.norm2(h4)
del h4

h6 = nonlinearity(h5)
del h5

h7 = self.dropout(h6)
del h6

h8 = self.conv2(h7)
del h7

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)

return x+h
return x + h8


class LinAttnBlock(LinearAttention):
Expand Down Expand Up @@ -178,28 +194,65 @@ def __init__(self, in_channels):
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)

# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
b, c, h, w = q1.shape

q2 = q1.reshape(b, c, h*w)
del q1

q = q2.permute(0, 2, 1) # b,hw,c
del q2

k = k1.reshape(b, c, h*w) # b,c,hw
del k1

h_ = torch.zeros_like(k, device=q.device)

stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch

tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size

w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2

# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3

h_ = self.proj_out(h_)
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4

return x+h_
h2 = h_.reshape(b, c, h, w)
del h_

h3 = self.proj_out(h2)
del h2

h3 += x

return h3


def make_attn(in_channels, attn_type="vanilla"):
Expand Down Expand Up @@ -540,31 +593,52 @@ def forward(self, z):
temb = None

# z to block_in
h = self.conv_in(z)
h1 = self.conv_in(z)

# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h2 = self.mid.block_1(h1, temb)
del h1

h3 = self.mid.attn_1(h2)
del h2

h = self.mid.block_2(h3, temb)
del h3

# prepare for up sampling
gc.collect()
torch.cuda.empty_cache()

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
t = h
h = self.up[i_level].attn[i_block](t)
del t
if i_level != 0:
h = self.up[i_level].upsample(h)
t = h
h = self.up[i_level].upsample(t)
del t

# end
if self.give_pre_end:
return h

h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h1 = self.norm_out(h)
del h

h2 = nonlinearity(h1)
del h1

h = self.conv_out(h2)
del h2

if self.tanh_out:
t = h
h = torch.tanh(h)
del t
return h


Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
3 changes: 2 additions & 1 deletion ldm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ def log_txt_as_img(wh, xc, size=10):
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
#font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))

Expand Down
Loading