Skip to content

Commit e58bdf2

Browse files
committed
feat(demo): allow multiple prompts
1 parent 66d8a92 commit e58bdf2

File tree

1 file changed

+45
-19
lines changed

1 file changed

+45
-19
lines changed

tools/inference/inference_pipeline.ipynb

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,8 @@
4747
"outputs": [],
4848
"source": [
4949
"# Install required libraries\n",
50-
"!pip install -q git+https://github.com/huggingface/transformers.git\n",
51-
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52-
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
50+
"!pip install -q dalle-mini\n",
51+
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
5352
]
5453
},
5554
{
@@ -250,7 +249,7 @@
250249
"id": "BQ7fymSPyvF_"
251250
},
252251
"source": [
253-
"Let's define a text prompt."
252+
"Let's define some text prompts."
254253
]
255254
},
256255
{
@@ -261,9 +260,18 @@
261260
},
262261
"outputs": [],
263262
"source": [
264-
"prompt = \"sunset over a lake in the mountains\""
263+
"prompts = [\"sunset over a lake in the mountains\", \"the Eiffel tower landing on the moon\"]"
265264
]
266265
},
266+
{
267+
"cell_type": "markdown",
268+
"source": [
269+
"Note: we could use the same prompt multiple times for faster inference."
270+
],
271+
"metadata": {
272+
"id": "XlZUG3SCLnGE"
273+
}
274+
},
267275
{
268276
"cell_type": "code",
269277
"execution_count": null,
@@ -272,7 +280,7 @@
272280
},
273281
"outputs": [],
274282
"source": [
275-
"tokenized_prompt = processor([prompt])"
283+
"tokenized_prompts = processor(prompts)"
276284
]
277285
},
278286
{
@@ -281,7 +289,7 @@
281289
"id": "-CEJBnuJOe5z"
282290
},
283291
"source": [
284-
"Finally we replicate it onto each device."
292+
"Finally we replicate the prompts onto each device."
285293
]
286294
},
287295
{
@@ -292,7 +300,7 @@
292300
},
293301
"outputs": [],
294302
"source": [
295-
"tokenized_prompt = replicate(tokenized_prompt)"
303+
"tokenized_prompt = replicate(tokenized_prompts)"
296304
]
297305
},
298306
{
@@ -314,10 +322,10 @@
314322
},
315323
"outputs": [],
316324
"source": [
317-
"# number of predictions\n",
325+
"# number of predictions per prompt\n",
318326
"n_predictions = 8\n",
319327
"\n",
320-
"# We can customize generation parameters\n",
328+
"# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
321329
"gen_top_k = None\n",
322330
"gen_top_p = None\n",
323331
"temperature = None\n",
@@ -337,7 +345,7 @@
337345
"from PIL import Image\n",
338346
"from tqdm.notebook import trange\n",
339347
"\n",
340-
"print(f\"Prompt: {prompt}\\n\")\n",
348+
"print(f\"Prompts: {prompts}\\n\")\n",
341349
"# generate images\n",
342350
"images = []\n",
343351
"for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
@@ -361,7 +369,8 @@
361369
" for decoded_img in decoded_images:\n",
362370
" img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
363371
" images.append(img)\n",
364-
" display(img)"
372+
" display(img)\n",
373+
" print()"
365374
]
366375
},
367376
{
@@ -415,17 +424,32 @@
415424
"\n",
416425
"# get clip scores\n",
417426
"clip_inputs = clip_processor(\n",
418-
" text=[prompt] * jax.device_count(),\n",
427+
" text=prompts * jax.device_count(),\n",
419428
" images=images,\n",
420429
" return_tensors=\"np\",\n",
421430
" padding=\"max_length\",\n",
422431
" max_length=77,\n",
423432
" truncation=True,\n",
424433
").data\n",
425434
"logits = p_clip(shard(clip_inputs), clip_params)\n",
426-
"logits = logits.squeeze().flatten()"
435+
"\n",
436+
"# organize scores per prompt\n",
437+
"p = len(prompts)\n",
438+
"logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()\n",
439+
"#logits = rearrange(logits, '1 b p -> p b')"
427440
]
428441
},
442+
{
443+
"cell_type": "code",
444+
"source": [
445+
"logits.shape"
446+
],
447+
"metadata": {
448+
"id": "ia0302WtRcEO"
449+
},
450+
"execution_count": null,
451+
"outputs": []
452+
},
429453
{
430454
"cell_type": "markdown",
431455
"metadata": {
@@ -443,10 +467,12 @@
443467
},
444468
"outputs": [],
445469
"source": [
446-
"print(f\"Prompt: {prompt}\\n\")\n",
447-
"for idx in logits.argsort()[::-1]:\n",
448-
" display(images[idx])\n",
449-
" print(f\"Score: {logits[idx]:.2f}\\n\")"
470+
"for i, prompt in enumerate(prompts):\n",
471+
" print(f\"Prompt: {prompt}\\n\")\n",
472+
" for idx in logits[i].argsort()[::-1]:\n",
473+
" display(images[idx*p+i])\n",
474+
" print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
475+
" print()"
450476
]
451477
}
452478
],
@@ -479,4 +505,4 @@
479505
},
480506
"nbformat": 4,
481507
"nbformat_minor": 0
482-
}
508+
}

0 commit comments

Comments
 (0)