
Commit ee87faf

llcourage authored and LIT team committed
Test and fix errors in dalle_mini examples. Now it should work fine locally.
PiperOrigin-RevId: 758762358
1 parent f2bde3d commit ee87faf

File tree

4 files changed: +47 -13 lines changed

lit_nlp/examples/dalle_mini/data.py

Lines changed: 10 additions & 5 deletions
@@ -5,14 +5,19 @@


 class DallePrompts(lit_dataset.Dataset):
+  """DallePrompts is a dataset that contains a list of prompts.
+
+  It is used to generate images using the dalle-mini model.
+  """

   def __init__(self, prompts: list[str]):
-    self.examples = []
+    self._examples = []
     for prompt in prompts:
-      self.examples.append({"prompt": prompt})
+      self._examples.append({"prompt": prompt})
+
+  @classmethod
+  def init_spec(cls) -> lit_types.Spec:
+    return {"prompt": lit_types.TextSegment(required=True)}

   def spec(self) -> lit_types.Spec:
     return {"prompt": lit_types.TextSegment()}
-
-  def __iter__(self):
-    return iter(self.examples)
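
For reference, a minimal usage sketch of the updated dataset, assuming the standard LIT import aliases (lit_nlp.api.dataset as lit_dataset, lit_nlp.api.types as lit_types); the rename to _examples matters because the base lit_dataset.Dataset serves data through its examples property rather than a plain attribute:

  # Hedged sketch; the import path below is assumed, not shown in this diff.
  from lit_nlp.examples.dalle_mini import data

  ds = data.DallePrompts(["I have a dream", "I have a shiba dog named cola"])
  print(ds.spec())                      # {"prompt": TextSegment()}
  print(len(ds.examples))               # 2, served by the base-class property
  print(data.DallePrompts.init_spec())  # {"prompt": TextSegment(required=True)}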

lit_nlp/examples/dalle_mini/demo.py

Lines changed: 34 additions & 2 deletions
@@ -1,8 +1,42 @@
 r"""Example for dalle-mini demo model.

+First run following command to install required packages:
+  pip install -r ./lit_nlp/examples/dalle_mini/requirements.txt
+
 To run locally with a small number of examples:
   python -m lit_nlp.examples.dalle_mini.demo

+By default, this module uses the "cuda" device for image generation.
+The `requirements.txt` file installs a CUDA-enabled version of PyTorch for GPU
+acceleration.
+
+If you are running on a machine without a compatible GPU or CUDA drivers,
+you must switch the device to "cpu" and reinstall the CPU-only version of
+PyTorch.
+
+Usage:
+  - Default: device="cuda"
+  - On CPU-only machines:
+    1. Set device="cpu" during model initialization
+    2. Uninstall the CUDA version of PyTorch:
+       pip uninstall torch
+    3. Install the CPU-only version:
+       pip install torch==2.1.2+cpu --extra-index-url
+       https://download.pytorch.org/whl/cpu
+
+Example:
+  >>> model = MinDalle(..., device="cpu")
+
+Check CUDA availability:
+  >>> import torch
+  >>> torch.cuda.is_available()
+  False  # if no GPU support is present
+
+Error Handling:
+  - If CUDA is selected but unsupported, you will see:
+    AssertionError: Torch not compiled with CUDA enabled
+  - To fix this, either install the correct CUDA-enabled PyTorch or switch to
+    CPU mode.

 Then navigate to localhost:5432 to access the demo UI.
 """

@@ -26,8 +60,6 @@
 _FLAGS.set_default("development_demo", True)
 _FLAGS.set_default("default_layout", "DALLE_LAYOUT")

-_FLAGS.DEFINE_integer("grid_size", 4, "The grid size to use for the model.")
-
 _MODELS = (["dalle-mini"],)

 _CANNED_PROMPTS = ["I have a dream", "I have a shiba dog named cola"]
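
The CPU fallback described in the new docstring can also be made automatic. A hedged sketch (the is_mega argument is an assumption about the min_dalle API, not part of this commit):

  # Hedged sketch: choose the device at runtime instead of assuming CUDA.
  import torch
  from min_dalle import MinDalle

  device = "cuda" if torch.cuda.is_available() else "cpu"
  model = MinDalle(is_mega=False, device=device)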

lit_nlp/examples/dalle_mini/model.py

Lines changed: 1 addition & 6 deletions
@@ -97,12 +97,7 @@ def tensor_to_pil_image(tensor):
     return images

   def input_spec(self):
-    return {
-        "grid_size": lit_types.Scalar(),
-        "temperature": lit_types.Scalar(),
-        "top_k": lit_types.Scalar(),
-        "supercondition_factor": lit_types.Scalar(),
-    }
+    return {"prompt": lit_types.TextSegment()}

   def output_spec(self):
     return {
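
With the scalar generation knobs removed, the model's input spec now mirrors the dataset spec, so examples only need a "prompt" key. A hedged sketch of calling the wrapper (model here stands for an already-constructed dalle-mini LIT model instance; the name is not taken from this diff):

  # Hedged sketch; `model` is assumed to be the dalle-mini LIT model wrapper.
  examples = [{"prompt": "I have a dream"}]   # the only key input_spec() requires
  outputs = list(model.predict(examples))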

lit_nlp/examples/dalle_mini/requirements.txt

Lines changed: 2 additions & 0 deletions
@@ -17,3 +17,5 @@

 # Dalle-Mini dependencies
 min_dalle==0.4.11
+torch==2.1.2+cu118
+--extra-index-url https://download.pytorch.org/whl/cu118
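
After installing from this file, a quick sanity check that the pinned CUDA wheel is in place (a hedged snippet, not part of the commit):

  import torch
  print(torch.__version__)          # expected: 2.1.2+cu118
  print(torch.version.cuda)         # expected: 11.8
  print(torch.cuda.is_available())  # True only with a working GPU and driver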
