diff --git a/sd/demo.ipynb b/sd/demo.ipynb
index 5eb9d56..b2510d0 100644
--- a/sd/demo.ipynb
+++ b/sd/demo.ipynb
@@ -17,6 +17,7 @@
     "import torch\n",
     "\n",
     "DEVICE = \"cpu\"\n",
+    "IDLE_DEVICE = \"cpu\"\n",
     "\n",
     "ALLOW_CUDA = False\n",
     "ALLOW_MPS = False\n",
@@ -29,7 +30,9 @@
     "\n",
     "tokenizer = CLIPTokenizer(\"../data/vocab.json\", merges_file=\"../data/merges.txt\")\n",
     "model_file = \"../data/v1-5-pruned-emaonly.ckpt\"\n",
-    "models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)\n",
+    "models = model_loader.preload_models_from_standard_weights(\n",
+    "    model_file, DEVICE, IDLE_DEVICE\n",
+    ")\n",
     "\n",
     "## TEXT TO IMAGE\n",
     "\n",
@@ -67,13 +70,20 @@
     "    seed=seed,\n",
     "    models=models,\n",
     "    device=DEVICE,\n",
-    "    idle_device=\"cpu\",\n",
+    "    idle_device=IDLE_DEVICE,\n",
     "    tokenizer=tokenizer,\n",
     ")\n",
     "\n",
     "# Combine the input image and the output image into a single image.\n",
     "Image.fromarray(output_image)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
diff --git a/sd/model_loader.py b/sd/model_loader.py
index 254e989..c8397e1 100644
--- a/sd/model_loader.py
+++ b/sd/model_loader.py
@@ -5,8 +5,8 @@
 
 import model_converter
 
-def preload_models_from_standard_weights(ckpt_path, device):
-    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)
+def preload_models_from_standard_weights(ckpt_path, device, idle_device):
+    state_dict = model_converter.load_from_standard_weights(ckpt_path, idle_device)
 
     encoder = VAE_Encoder().to(device)
     encoder.load_state_dict(state_dict['encoder'], strict=True)
diff --git a/sd/pipeline.py b/sd/pipeline.py
index 4c57c94..ed4c4db 100644
--- a/sd/pipeline.py
+++ b/sd/pipeline.py
@@ -149,6 +149,8 @@ def generate(
         # (Batch_Size, Channel, Height, Width) -> (Batch_Size, Height, Width, Channel)
         images = images.permute(0, 2, 3, 1)
         images = images.to("cpu", torch.uint8).numpy()
+
+        torch.cuda.empty_cache()
         return images[0]
 
 def rescale(x, old_range, new_range, clamp=False):
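
Note: the new `idle_device` argument follows the offloading pattern the pipeline already relies on: checkpoint weights are materialized on a "parking" device (here the CPU), and a submodule is moved to the compute device only while it runs. A minimal sketch of that pattern, with a hypothetical `run_on` helper that is not part of this repository's API:

```python
import torch


def run_on(module: torch.nn.Module, x: torch.Tensor,
           device: str = "cuda", idle_device: str = "cpu") -> torch.Tensor:
    """Run `module` on `device`, then park its weights on `idle_device`."""
    module.to(device)                 # bring weights onto the compute device
    with torch.no_grad():
        out = module(x.to(device))
    module.to(idle_device)            # move weights back off the GPU
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # return cached, now-unused blocks
    return out
```

`torch.cuda.empty_cache()` only releases PyTorch's cached, unreferenced allocations back to the driver; it does not free live tensors, so calling it after the output image has been copied to the CPU (as the pipeline.py hunk does) is safe.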