4747 "outputs" : [],
4848 "source" : [
4949 " # Install required libraries\n " ,
50- " !pip install -q git+https://github.com/huggingface/transformers.git\n " ,
51- " !pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n " ,
52- " !pip install -q git+https://github.com/borisdayma/dalle-mini.git"
50+ " !pip install -q dalle-mini\n " ,
51+ " !pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
5352 ]
5453 },
5554 {
250249 "id" : " BQ7fymSPyvF_"
251250 },
252251 "source" : [
253- " Let's define a text prompt ."
252+ " Let's define some text prompts ."
254253 ]
255254 },
256255 {
261260 },
262261 "outputs" : [],
263262 "source" : [
264- " prompt = \" sunset over a lake in the mountains\" "
263+ " prompts = [ \" sunset over a lake in the mountains\" , \" the Eiffel tower landing on the moon \" ] "
265264 ]
266265 },
266+ {
267+ "cell_type" : " markdown" ,
268+ "source" : [
269+ " Note: the same prompt can be repeated several times in the list to generate multiple predictions for it in a single batch, which gives faster inference."
270+ ],
271+ "metadata" : {
272+ "id" : " XlZUG3SCLnGE"
273+ }
274+ },
267275 {
268276 "cell_type" : " code" ,
269277 "execution_count" : null ,
272280 },
273281 "outputs" : [],
274282 "source" : [
275- " tokenized_prompt = processor([prompt] )"
283+ " tokenized_prompts = processor(prompts )"
276284 ]
277285 },
278286 {
281289 "id" : " -CEJBnuJOe5z"
282290 },
283291 "source" : [
284- " Finally we replicate it onto each device."
292+ " Finally we replicate the tokenized prompts onto each device."
285293 ]
286294 },
287295 {
292300 },
293301 "outputs" : [],
294302 "source" : [
295- " tokenized_prompt = replicate(tokenized_prompt )"
303+ " tokenized_prompt = replicate(tokenized_prompts )"
296304 ]
297305 },
298306 {
314322 },
315323 "outputs" : [],
316324 "source" : [
317- " # number of predictions\n " ,
325+ " # number of predictions per prompt \n " ,
318326 " n_predictions = 8\n " ,
319327 " \n " ,
320- " # We can customize generation parameters\n " ,
328+ " # We can customize generation parameters (see https://huggingface.co/blog/how-to-generate) \n " ,
321329 " gen_top_k = None\n " ,
322330 " gen_top_p = None\n " ,
323331 " temperature = None\n " ,
337345 " from PIL import Image\n " ,
338346 " from tqdm.notebook import trange\n " ,
339347 " \n " ,
340- " print(f\" Prompt : {prompt }\\ n\" )\n " ,
348+ " print(f\" Prompts : {prompts }\\ n\" )\n " ,
341349 " # generate images\n " ,
342350 " images = []\n " ,
343351 " for i in trange(max(n_predictions // jax.device_count(), 1)):\n " ,
361369 " for decoded_img in decoded_images:\n " ,
362370 " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n " ,
363371 " images.append(img)\n " ,
364- " display(img)"
372+ " display(img)\n " ,
373+ " print()"
365374 ]
366375 },
367376 {
415424 " \n " ,
416425 " # get clip scores\n " ,
417426 " clip_inputs = clip_processor(\n " ,
418- " text=[prompt] * jax.device_count(),\n " ,
427+ " text=prompts * jax.device_count(),\n " ,
419428 " images=images,\n " ,
420429 " return_tensors=\" np\" ,\n " ,
421430 " padding=\" max_length\" ,\n " ,
422431 " max_length=77,\n " ,
423432 " truncation=True,\n " ,
424433 " ).data\n " ,
425434 " logits = p_clip(shard(clip_inputs), clip_params)\n " ,
426- " logits = logits.squeeze().flatten()"
435+ " \n " ,
436+ " # organize scores per prompt\n " ,
437+ " p = len(prompts)\n " ,
438+ " logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()\n " ,
439+ " #logits = rearrange(logits, '1 b p -> p b')"
427440 ]
428441 },
442+ {
443+ "cell_type" : " code" ,
444+ "source" : [
445+ " logits.shape"
446+ ],
447+ "metadata" : {
448+ "id" : " ia0302WtRcEO"
449+ },
450+ "execution_count" : null ,
451+ "outputs" : []
452+ },
429453 {
430454 "cell_type" : " markdown" ,
431455 "metadata" : {
443467 },
444468 "outputs" : [],
445469 "source" : [
446- " print(f\" Prompt: {prompt}\\ n\" )\n " ,
447- " for idx in logits.argsort()[::-1]:\n " ,
448- " display(images[idx])\n " ,
449- " print(f\" Score: {logits[idx]:.2f}\\ n\" )"
470+ " for i, prompt in enumerate(prompts):\n " ,
471+ " print(f\" Prompt: {prompt}\\ n\" )\n " ,
472+ " for idx in logits[i].argsort()[::-1]:\n " ,
473+ " display(images[idx*p+i])\n " ,
474+ " print(f\" Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\ n\" )\n " ,
475+ " print()"
450476 ]
451477 }
452478 ],
479505 },
480506 "nbformat" : 4 ,
481507 "nbformat_minor" : 0
482- }
508+ }
0 commit comments