Skip to content

Commit a5c811e

Browse files
Merge pull request #283 from StephanAkkerman/feat/prompt-improv
Improve the prompts
2 parents 41e1dbd + bd58d7d commit a5c811e

File tree

10 files changed

+915
-166
lines changed

10 files changed

+915
-166
lines changed

.github/workflows/pyversions.yml

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,16 @@ jobs:
2828
python-version: ${{ matrix.python-version }}
2929
cache: 'pip' # enable built-in cache
3030
cache-dependency-path: backend/requirements.txt
31-
32-
- name: Resolve dependencies only
31+
32+
- name: Prepare CI-only requirements
3333
run: |
3434
cd backend
35-
python -m pip install --upgrade pip
36-
# Resolve requirements and editable package - no wheels pulled, no files installed
37-
python -m pip install --dry-run -e . -r requirements.txt
38-
35+
# remove any "nunchaku @ …" lines
36+
grep -v '^nunchaku @ ' requirements.txt > requirements-ci.txt
3937
38+
- name: Resolve everything (dry-run)
39+
run: |
40+
cd backend
41+
python -m pip install --upgrade pip setuptools wheel
42+
# install everything *except* nunchaku
43+
pip install --dry-run -r requirements-ci.txt

backend/mnemorai/constants/languages.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,156 @@
55
with open(config.get("G2P").get("LANGUAGE_JSON")) as f:
66
G2P_LANGCODES = json.load(f)
77
G2P_LANGUAGES: dict = dict(map(reversed, G2P_LANGCODES.items()))
8+
9+
EPITRAN_LANGCODES = {
10+
"aar-Latn": "Afar",
11+
"afr-Latn": "Afrikaans",
12+
"aii-Syrc": "Assyrian Neo-Aramaic",
13+
"amh-Ethi": "Amharic",
14+
"amh-Ethi-pp": "Amharic (more phonetic)",
15+
"amh-Ethi-red": "Amharic (reduced)",
16+
"ara-Arab": "Literary Arabic",
17+
"ava-Cyrl": "Avaric",
18+
"aze-Cyrl": "Azerbaijani (Cyrillic)",
19+
"aze-Latn": "Azerbaijani",
20+
"ben-Beng": "Bengali",
21+
"ben-Beng-red": "Bengali (reduced)",
22+
"ben-Beng-east": "Eastern Bengali",
23+
"bho-Deva": "Bhojpuri",
24+
"bxk-Latn": "Bukusu",
25+
"cat-Latn": "Catalan",
26+
"ceb-Latn": "Cebuano",
27+
"ces-Latn": "Czech",
28+
"cjy-Latn": "Jin (Wiktionary)",
29+
"ckb-Arab": "Sorani",
30+
"cmn-Hans": "Mandarin (Simplified)*",
31+
"cmn-Hant": "Mandarin (Traditional)*",
32+
"cmn-Latn": "Mandarin (Pinyin)*",
33+
"csb-Latn": "Kashubian",
34+
"deu-Latn": "German",
35+
"deu-Latn-np": "German†",
36+
"deu-Latn-nar": "German (more phonetic)",
37+
"eng-Latn": "English",
38+
"epo-Latn": "Esperanto",
39+
"est-Latn": "Estonian",
40+
"fas-Arab": "Farsi (Perso-Arabic)",
41+
"fin-Latn": "Finnish",
42+
"fra-Latn": "French",
43+
"fra-Latn-np": "French†",
44+
"fra-Latn-p": "French (more phonetic)",
45+
"ful-Latn": "Fulah",
46+
"gan-Latn": "Gan (Wiktionary)",
47+
"glg-Latn": "Galician",
48+
"got-Goth": "Gothic",
49+
"got-Latn": "Gothic (Latin)",
50+
"hak-Latn": "Hakka (pha̍k-fa-sṳ)",
51+
"hat-Latn-bab": "Haitian (Latin-Babel)",
52+
"hau-Latn": "Hausa",
53+
"hin-Deva": "Hindi",
54+
"hmn-Latn": "Hmong",
55+
"hrv-Latn": "Croatian",
56+
"hsn-Latn": "Xiang (Wiktionary)",
57+
"hun-Latn": "Hungarian",
58+
"ilo-Latn": "Ilocano",
59+
"ind-Latn": "Indonesian",
60+
"ita-Latn": "Italian",
61+
"jam-Latn": "Jamaican",
62+
"jav-Latn": "Javanese",
63+
"jpn-Hrgn": "Japanese (Hiragana)",
64+
"jpn-Hrgn-red": "Japanese (Hiragana, reduced)",
65+
"jpn-Ktkn": "Japanese (Katakana)",
66+
"jpn-Ktkn-red": "Japanese (Katakana, reduced)",
67+
"jpn-Jpan": "Japanese (Hiragana, Katakana, Kanji)",
68+
"jpn-Hira": "Japanese (Hiragana)",
69+
"jpn-Hira-red": "Japanese (Hiragana, reduced)",
70+
"jpn-Kana": "Japanese (Katakana)",
71+
"jpn-Kana-red": "Japanese (Katakana, reduced)",
72+
"kat-Geor": "Georgian",
73+
"kaz-Cyrl": "Kazakh (Cyrillic)",
74+
"kaz-Cyrl-bab": "Kazakh (Cyrillic—Babel)",
75+
"kaz-Latn": "Kazakh (Latin)",
76+
"kbd-Cyrl": "Kabardian",
77+
"khm-Khmr": "Khmer",
78+
"kin-Latn": "Kinyarwanda",
79+
"kir-Arab": "Kyrgyz (Perso-Arabic)",
80+
"kir-Cyrl": "Kyrgyz (Cyrillic)",
81+
"kir-Latn": "Kyrgyz (Latin)",
82+
"kmr-Latn": "Kurmanji",
83+
"kmr-Latn-red": "Kurmanji (reduced)",
84+
"kor-Hang": "Korean",
85+
"lao-Laoo": "Lao",
86+
"lao-Laoo-prereform": "Lao (Before spelling reform)",
87+
"lav-Latn": "Latvian",
88+
"lez-Cyrl": "Lezgian",
89+
"lij-Latn": "Ligurian",
90+
"lit-Latn": "Lithuanian",
91+
"lsm-Latn": "Saamia",
92+
"ltc-Latn-bax": "Middle Chinese (Baxter and Sagart 2014)",
93+
"lug-Latn": "Ganda / Luganda",
94+
"mal-Mlym": "Malayalam",
95+
"mar-Deva": "Marathi",
96+
"mlt-Latn": "Maltese",
97+
"mon-Cyrl-bab": "Mongolian (Cyrillic)",
98+
"mri-Latn": "Maori",
99+
"msa-Latn": "Malay",
100+
"mya-Mymr": "Burmese",
101+
"nan-Latn": "Hokkien (pe̍h-oē-jī)",
102+
"nan-Latn-tl": "Hokkien (Tâi-lô)",
103+
"nld-Latn": "Dutch",
104+
"nya-Latn": "Chichewa",
105+
"ood-Latn-alv": "Tohono O'odham (Alvarez-Hale)",
106+
"ood-Latn-sax": "Tohono O'odham (Saxton)",
107+
"ori-Orya": "Odia",
108+
"orm-Latn": "Oromo",
109+
"pan-Guru": "Punjabi (Eastern)",
110+
"pol-Latn": "Polish",
111+
"por-Latn": "Portuguese",
112+
"quy-Latn": "Ayacucho Quechua / Quechua Chanka",
113+
"ron-Latn": "Romanian",
114+
"run-Latn": "Rundi",
115+
"rus-Cyrl": "Russian",
116+
"sag-Latn": "Sango",
117+
"sin-Sinh": "Sinhala",
118+
"slv-Latn": "Slovene / Slovenian",
119+
"sna-Latn": "Shona",
120+
"som-Latn": "Somali",
121+
"spa-Latn": "Spanish",
122+
"spa-Latn-eu": "Spanish (Iberian)",
123+
"sqi-Latn": "Albanian",
124+
"sro-Latn": "Sardinian (Campidanese)",
125+
"srp-Latn": "Serbian (Latin)",
126+
"srp-Cyrl": "Serbian (Cyrillic)",
127+
"swa-Latn": "Swahili",
128+
"swa-Latn-red": "Swahili (reduced)",
129+
"swe-Latn": "Swedish",
130+
"tam-Taml": "Tamil",
131+
"tam-Taml-red": "Tamil (reduced)",
132+
"tel-Telu": "Telugu",
133+
"tgk-Cyrl": "Tajik",
134+
"tgl-Latn": "Tagalog",
135+
"tgl-Latn-red": "Tagalog (reduced)",
136+
"tha-Thai": "Thai",
137+
"tir-Ethi": "Tigrinya",
138+
"tir-Ethi-pp": "Tigrinya (more phonemic)",
139+
"tir-Ethi-red": "Tigrinya (reduced)",
140+
"tok-Latn": "Toki Pona",
141+
"tpi-Latn": "Tok Pisin",
142+
"tuk-Cyrl": "Turkmen (Cyrillic)",
143+
"tuk-Latn": "Turkmen (Latin)",
144+
"tur-Latn": "Turkish (Latin)",
145+
"tur-Latn-bab": "Turkish (Latin—Babel)",
146+
"tur-Latn-red": "Turkish (reduced)",
147+
"ukr-Cyrl": "Ukrainian",
148+
"urd-Arab": "Urdu",
149+
"uig-Arab": "Uyghur (Perso-Arabic)",
150+
"uzb-Cyrl": "Uzbek (Cyrillic)",
151+
"uzb-Latn": "Uzbek (Latin)",
152+
"vie-Latn": "Vietnamese",
153+
"wuu-Latn": "Shanghainese Wu (Wiktionary)",
154+
"xho-Latn": "Xhosa",
155+
"yor-Latn": "Yoruba",
156+
"yue-Latn": "Cantonese (Jyutping)",
157+
"yue-Hant": "Cantonese (Character)",
158+
"zha-Latn": "Zhuang",
159+
"zul-Latn": "Zulu",
160+
}

backend/mnemorai/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,4 @@ async def generate_mnemonic_img(
8585

8686
if __name__ == "__main__":
8787
pipeline = MnemonicPipeline()
88-
print(asyncio.run(pipeline.generate_mnemonic_img("ratatouille", "eng-us")))
88+
print(asyncio.run(pipeline.generate_mnemonic_img("tikus", "ind")))

backend/mnemorai/services/imagine/image_gen.py

Lines changed: 101 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33

44
import torch
55
from diffusers import (
6+
AutoencoderKL,
67
AutoPipelineForText2Image,
7-
SanaPipeline,
8-
SanaTransformer2DModel,
8+
FlowMatchEulerDiscreteScheduler,
9+
FluxPipeline,
910
)
10-
from diffusers import (
11-
BitsAndBytesConfig as DiffusersBitsAndBytesConfig,
12-
)
13-
from transformers import AutoModel
11+
from huggingface_hub import hf_hub_download
12+
from nunchaku import NunchakuT5EncoderModel
13+
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
14+
from nunchaku.utils import get_precision
1415
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
16+
from transformers import CLIPTextModel, CLIPTokenizer, T5TokenizerFast
1517

1618
from mnemorai.constants.config import config
1719
from mnemorai.logger import logger
@@ -37,69 +39,117 @@ def __init__(self, model: str = None):
3739
os.makedirs(self.output_dir, exist_ok=True)
3840
self.image_gen_params = self.config.get("PARAMS", {})
3941

42+
# if seed is provided, set it
43+
if "seed" in self.image_gen_params:
44+
if not isinstance(self.image_gen_params["seed"], int):
45+
logger.warning("Seed must be an integer. Using no seed.")
46+
else:
47+
self.image_gen_params["generator"] = torch.Generator(
48+
device="cuda"
49+
).manual_seed(self.image_gen_params["seed"])
50+
# remove seed from params to avoid passing it to the pipeline
51+
del self.image_gen_params["seed"]
52+
4053
# Initialize pipe to None; will be loaded on first use
4154
self.pipe = None
4255

4356
def _get_pipe_func(self):
44-
if "sana" in self.model_name.lower():
45-
return SanaPipeline
57+
if "flux" in self.model_name.lower():
58+
return FluxPipeline
4659
else:
4760
return AutoPipelineForText2Image
4861

4962
def _initialize_pipe(self):
5063
"""Initialize the pipeline."""
5164
pipe_func = self._get_pipe_func()
5265
logger.debug(f"Initializing pipeline for model: {self.model}")
66+
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
5367

54-
quantization = self.config.get("QUANTIZATION")
68+
if "flux" in self.model_name.lower():
69+
bfl_repo = "black-forest-labs/FLUX.1-dev"
70+
device = "cuda"
5571

56-
if quantization != "4bit" and quantization != "8bit":
57-
logger.debug("Using default model loading without quantization")
58-
self.pipe = pipe_func.from_pretrained(
59-
self.model,
60-
torch_dtype=torch.float16,
61-
variant="fp16",
72+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
73+
bfl_repo,
74+
subfolder="scheduler",
75+
torch_dtype=dtype,
6276
cache_dir="models",
6377
)
64-
else:
65-
if "sana" in self.model_name.lower():
66-
if quantization == "8bit":
67-
quant_config = BitsAndBytesConfig(load_in_8bit=True)
68-
logger.debug("Using 8-bit quantization for Sana model")
69-
elif quantization == "4bit":
70-
quant_config = BitsAndBytesConfig(load_in_4bit=True)
71-
logger.debug("Using 4-bit quantization for Sana model")
72-
else:
73-
raise ValueError(
74-
f"Invalid quantization type. Use '8bit' or '4bit'. Your quantization is: {quantization}"
78+
text_encoder = CLIPTextModel.from_pretrained(
79+
bfl_repo,
80+
subfolder="text_encoder",
81+
torch_dtype=dtype,
82+
cache_dir="models",
83+
)
84+
# T5 encoder in int4
85+
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
86+
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors",
87+
cache_dir="models",
88+
)
89+
tokenizer = CLIPTokenizer.from_pretrained(
90+
bfl_repo,
91+
subfolder="tokenizer",
92+
torch_dtype=dtype,
93+
clean_up_tokenization_spaces=True,
94+
cache_dir="models",
95+
)
96+
tokenizer_2 = T5TokenizerFast.from_pretrained(
97+
bfl_repo,
98+
subfolder="tokenizer_2",
99+
torch_dtype=dtype,
100+
clean_up_tokenization_spaces=True,
101+
cache_dir="models",
102+
)
103+
vae = AutoencoderKL.from_pretrained(
104+
bfl_repo,
105+
subfolder="vae",
106+
torch_dtype=dtype,
107+
cache_dir="models",
108+
)
109+
precision = (
110+
get_precision()
111+
) # auto-detect your precision is 'int4' or 'fp4' based on your GPU
112+
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
113+
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
114+
offload=self.config.get("OFFLOAD_T5", True),
115+
)
116+
# Set attention implementation to fp16
117+
transformer.set_attention_impl("nunchaku-fp16")
118+
119+
params = {
120+
"scheduler": scheduler,
121+
"vae": vae,
122+
"tokenizer": tokenizer,
123+
"tokenizer_2": tokenizer_2,
124+
"text_encoder": text_encoder,
125+
"text_encoder_2": text_encoder_2,
126+
"transformer": transformer,
127+
}
128+
self.pipe = FluxPipeline(**params) # .to(device, dtype=dtype)
129+
130+
lora_config = self.config.get("FLUX_LORA", {})
131+
if lora_config.get("USE_LORA", False):
132+
logger.info("Loading LoRA weights for FLUX model.")
133+
transformer.update_lora_params(
134+
hf_hub_download(
135+
lora_config.get("LORA_REPO"),
136+
lora_config.get("LORA_FILE"),
75137
)
76-
77-
text_encoder_8bit = AutoModel.from_pretrained(
78-
self.model,
79-
subfolder="text_encoder",
80-
quantization_config=quant_config,
81-
torch_dtype=torch.float16,
82-
cache_dir="models",
83138
)
84139

85-
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
86-
transformer_8bit = SanaTransformer2DModel.from_pretrained(
87-
self.model,
88-
subfolder="transformer",
89-
quantization_config=quant_config,
90-
torch_dtype=torch.float16,
91-
cache_dir="models",
92-
)
140+
transformer.set_lora_strength(lora_config.get("LORA_SCALE", 1.0))
93141

94-
self.pipe = SanaPipeline.from_pretrained(
95-
self.model,
96-
text_encoder=text_encoder_8bit,
97-
transformer=transformer_8bit,
98-
torch_dtype=torch.float16,
99-
device_map="balanced",
100-
)
101-
else:
102-
raise NotImplementedError("Quantization not supported for this model.")
142+
# offload, does not decrease performance
143+
if self.config.get("SEQUENTIAL_OFFLOAD", True):
144+
logger.info("Enabling sequential CPU offload for FLUX model.")
145+
self.pipe.enable_sequential_cpu_offload(device=device)
146+
else:
147+
self.pipe = pipe_func.from_pretrained(
148+
self.model,
149+
torch_dtype=dtype,
150+
variant="fp16" if dtype == torch.float16 else None,
151+
cache_dir="models",
152+
)
103153

104154
@manage_memory(
105155
targets=["pipe"],
@@ -108,7 +158,7 @@ def _initialize_pipe(self):
108158
)
109159
def generate_img(
110160
self,
111-
prompt: str = "Imagine a flashy bottle that stands out from the other bottles.",
161+
prompt: str = "A flashy bottle that stands out from the other bottles.",
112162
word1: str = "flashy",
113163
word2: str = "bottle",
114164
):
@@ -126,9 +176,6 @@ def generate_img(
126176
"""
127177
file_path = self.output_dir / f"{word1}_{word2}_{self.model_name}.png"
128178

129-
# Clean prompt by dropping "imagine " prefix
130-
prompt = prompt.lower().lstrip("imagine").strip()
131-
132179
logger.info(f"Generating image for prompt: {prompt}")
133180
image = self.pipe(prompt=prompt, **self.image_gen_params).images[0]
134181
logger.info(f"Saving image to: {file_path}")

0 commit comments

Comments
 (0)