diff --git a/xclip.py b/xclip.py
new file mode 100644
index 00000000..1cdacc8e
--- /dev/null
+++ b/xclip.py
@@ -0,0 +1,133 @@
+import torch
+from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
+from PIL import Image
+
+clip = CLIP(
+    dim_text = 512,
+    dim_image = 512,
+    dim_latent = 512,
+    num_text_tokens = 49408,
+    text_enc_depth = 6,
+    text_seq_len = 256,
+    text_heads = 8,
+    visual_enc_depth = 6,
+    visual_image_size = 256,
+    visual_patch_size = 32,
+    visual_heads = 8
+).cuda()
+
+# mock data
+
+def generate_mock_data(batch_size, seq_len, num_text_tokens, image_size):
+    text = torch.randint(0, num_text_tokens, (batch_size, seq_len)).cuda()
+    images = torch.randn(batch_size, 3, image_size, image_size).cuda()
+    return text, images
+
+num_batches = 1000
+batch_size = 4
+seq_len = 256
+num_text_tokens = 49408
+image_size = 256
+
+# train CLIP
+
+for _ in range(num_batches):
+    text, images = generate_mock_data(batch_size, seq_len, num_text_tokens, image_size)
+
+    # calculate the contrastive loss
+    loss = clip(
+        text,
+        images,
+        return_loss = True  # needs to be set to True to return the contrastive loss
+    )
+
+    # backpropagate the loss (a real training loop would also step an optimizer
+    # and zero the gradients; see the sketch at the end of this file)
+    loss.backward()
+
+# do the above for many steps
+
+# prior network (with transformer)
+
+prior_network = DiffusionPriorNetwork(
+    dim = 512,
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    clip = clip,
+    timesteps = 1000,
+    sample_timesteps = 64,
+    cond_drop_prob = 0.2
+).cuda()
+
+loss = diffusion_prior(text, images)
+loss.backward()
+
+# do the above for many steps ...
+
+# decoder (with unets)
+
+unet1 = Unet(
+    dim = 128,
+    image_embed_dim = 512,
+    text_embed_dim = 512,
+    cond_dim = 128,
+    channels = 3,
+    dim_mults = (1, 2, 4, 8),
+    cond_on_text_encodings = True  # set to True for any unets that need to be conditioned on text encodings
+).cuda()
+
+unet2 = Unet(
+    dim = 16,
+    image_embed_dim = 512,
+    cond_dim = 128,
+    channels = 3,
+    dim_mults = (1, 2, 4, 8, 16)
+).cuda()
+
+decoder = Decoder(
+    unet = (unet1, unet2),
+    image_sizes = (128, 256),
+    clip = clip,
+    timesteps = 100,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
+).cuda()
+
+for unet_number in (1, 2):
+    loss = decoder(images, text = text, unet_number = unet_number)  # passing text is optional; the paper hinted text conditioning in the decoder didn't do much
+    loss.backward()
+
+# do the above for many steps
+
+dalle2 = DALLE2(
+    prior = diffusion_prior,
+    decoder = decoder
+)
+
+images = dalle2(
+    ['cute puppy chasing after a squirrel'],
+    cond_scale = 2.  # classifier free guidance strength (> 1 strengthens the conditioning)
+)
+
+# save your image (in this example, of size 256x256)
+generated_image = images[0].clamp(0., 1.)  # select the first image in the batch and clamp to [0, 1] before quantizing
+pil_image = Image.fromarray((generated_image.permute(1, 2, 0).cpu().numpy() * 255).astype('uint8'))
+pil_image.save('generated_image.png')
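+
+# The training loops above call loss.backward() without ever stepping an
+# optimizer, mirroring the upstream dalle2-pytorch README. Below is a minimal
+# sketch of a complete CLIP optimization step; torch.optim.Adam and the 3e-4
+# learning rate are illustrative assumptions, not part of the original recipe.
+def train_clip_sketch(clip, num_steps = 1000):
+    opt = torch.optim.Adam(clip.parameters(), lr = 3e-4)
+    for _ in range(num_steps):
+        text, images = generate_mock_data(batch_size, seq_len, num_text_tokens, image_size)
+        loss = clip(text, images, return_loss = True)
+        opt.zero_grad()
+        loss.backward()
+        opt.step()  # apply the gradients from this batch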