From 1f29afbf491d8567d44b32822725b8b355415894 Mon Sep 17 00:00:00 2001 From: llcourage Date: Wed, 14 May 2025 00:10:54 -0400 Subject: [PATCH] Test and fix errors in dalle_mini examples. --- lit_nlp/examples/dalle_mini/data.py | 10 ++++-- lit_nlp/examples/dalle_mini/demo.py | 34 ++++++++++++++++++-- lit_nlp/examples/dalle_mini/model.py | 9 ++---- lit_nlp/examples/dalle_mini/requirements.txt | 2 ++ 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/lit_nlp/examples/dalle_mini/data.py b/lit_nlp/examples/dalle_mini/data.py index e54b1ca3..e77f29fa 100644 --- a/lit_nlp/examples/dalle_mini/data.py +++ b/lit_nlp/examples/dalle_mini/data.py @@ -7,12 +7,16 @@ class DallePrompts(lit_dataset.Dataset): def __init__(self, prompts: list[str]): - self.examples = [] + self._examples = [] for prompt in prompts: - self.examples.append({"prompt": prompt}) + self._examples.append({"prompt": prompt}) def spec(self) -> lit_types.Spec: return {"prompt": lit_types.TextSegment()} def __iter__(self): - return iter(self.examples) + return iter(self._examples) + + @property + def examples(self): + return self._examples diff --git a/lit_nlp/examples/dalle_mini/demo.py b/lit_nlp/examples/dalle_mini/demo.py index 18cbc885..7b6888e1 100644 --- a/lit_nlp/examples/dalle_mini/demo.py +++ b/lit_nlp/examples/dalle_mini/demo.py @@ -1,9 +1,39 @@ r"""Example for dalle-mini demo model. +First run the following command to install the required packages: + pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt + To run locally with a small number of examples: python -m lit_nlp.examples.dalle_mini.demo - +By default, this module uses the "cuda" device for image generation. +The `requirements.txt` file installs a CUDA-enabled version of PyTorch for GPU acceleration. + +If you are running on a machine without a compatible GPU or CUDA drivers, +you must switch the device to "cpu" and reinstall the CPU-only version of PyTorch. + +Usage: + - Default: device="cuda" + - On CPU-only machines: + 1. 
Set device="cpu" during model initialization + 2. Uninstall the CUDA version of PyTorch: + pip uninstall torch + 3. Install the CPU-only version: + pip install torch==2.1.2+cpu --extra-index-url https://download.pytorch.org/whl/cpu + +Example: + >>> model = MinDalle(..., device="cpu") + +Check CUDA availability: + >>> import torch + >>> torch.cuda.is_available() + False # if no GPU support is present + +Error Handling: + - If CUDA is selected but unsupported, you will see: + AssertionError: Torch not compiled with CUDA enabled + - To fix this, either install the correct CUDA-enabled PyTorch or switch to CPU mode. + Then navigate to localhost:5432 to access the demo UI. """ @@ -26,8 +56,6 @@ _FLAGS.set_default("development_demo", True) _FLAGS.set_default("default_layout", "DALLE_LAYOUT") -_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.") - _MODELS = (["dalle-mini"],) _CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"] diff --git a/lit_nlp/examples/dalle_mini/model.py b/lit_nlp/examples/dalle_mini/model.py index 487072d6..6bb481c8 100644 --- a/lit_nlp/examples/dalle_mini/model.py +++ b/lit_nlp/examples/dalle_mini/model.py @@ -97,13 +97,8 @@ def tensor_to_pil_image(tensor): return images def input_spec(self): - return { - "grid_size": lit_types.Scalar(), - "temperature": lit_types.Scalar(), - "top_k": lit_types.Scalar(), - "supercondition_factor": lit_types.Scalar(), - } - + return {"prompt": lit_types.TextSegment()} + def output_spec(self): return { "image": lit_types.ImageBytesList(), diff --git a/lit_nlp/examples/dalle_mini/requirements.txt b/lit_nlp/examples/dalle_mini/requirements.txt index b5199a94..c454da7d 100644 --- a/lit_nlp/examples/dalle_mini/requirements.txt +++ b/lit_nlp/examples/dalle_mini/requirements.txt @@ -17,3 +17,5 @@ # Dalle-Mini dependencies min_dalle==0.4.11 +torch==2.1.2+cu118 +--extra-index-url https://download.pytorch.org/whl/cu118