Skip to content

Commit caaa102

Browse files
authored
add disco diffusion stable diffusion into taskflow (#3198)
* add disco diffusion stable diffusion into taskflow * update num_return_images in taskflow readme * update text2image taskflow readme * rename text2image_generation text_to_image * rename
1 parent aafd5a9 commit caaa102

File tree

5 files changed

+533
-45
lines changed

5 files changed

+533
-45
lines changed

docs/model_zoo/taskflow.md

Lines changed: 112 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ PaddleNLP提供**开箱即用**的产业级NLP预置任务能力,无需训练
4141
| [智能写诗](#智能写诗) | `Taskflow("poetry_generation")` |||| | | 使用最大中文开源CPM模型完成写诗 |
4242
| [开放域对话](#开放域对话) | `Taskflow("dialogue")` |||| | | 十亿级语料训练最强中文闲聊模型PLATO-Mini,支持多轮对话 |
4343
| [代码生成](#代码生成) | `Taskflow("code_generation")` |||| | | 代码生成大模型 |
44-
| [文图生成](#文图生成) | `Taskflow("text2image_generation")` |||| | | 文图生成大模型 |
44+
| [文图生成](#文图生成) | `Taskflow("text_to_image")` |||| | | 文图生成大模型 |
4545
| [文本摘要](#文本摘要) | `Taskflow("text_summarization")` ||||| | 文本摘要大模型 |
4646

4747

@@ -1391,30 +1391,115 @@ from paddlenlp import Taskflow
13911391
>>> from paddlenlp import Taskflow
13921392
# 默认模型为 pai-painter-painting-base-zh
13931393
>>> text_to_image = Taskflow("text_to_image")
1394-
# 单条输入
1395-
>>> images = text_to_image("风阁水帘今在眼,且来先看早梅红")
1396-
# [<PIL.Image.Image image mode=RGB size=2048x256>]
1397-
>>> images[0].save("painting-figure.png")
1398-
# 多条输入
1399-
>>> images = text_to_image(["风阁水帘今在眼,且来先看早梅红", "见说春风偏有贺,露花千朵照庭闹"])
1400-
# [<PIL.Image.Image image mode=RGB size=2048x256>,
1401-
# <PIL.Image.Image image mode=RGB size=2048x256>]
1402-
>>> for i, image in enumerate(images):
1403-
>>> image.save(f"painting-figure_{i}.png")
1404-
# pai-painter-commercial-base-zh模型
1394+
# 单条输入, 默认返回2张图片。
1395+
>>> image_list = text_to_image("风阁水帘今在眼,且来先看早梅红")
1396+
# [[<PIL.Image.Image image mode=RGB size=256x256>, <PIL.Image.Image image mode=RGB size=256x256>]]
1397+
>>> image_list[0][0].save("painting-figure-1.png")
1398+
>>> image_list[0][1].save("painting-figure-2.png")
1399+
>>> image_list[0][0].argument
1400+
# argument表示生成该图片所使用的参数
1401+
# {'input': '风阁水帘今在眼,且来先看早梅红',
1402+
# 'batch_size': 1,
1403+
# 'seed': 2414128200,
1404+
# 'temperature': 1.0,
1405+
# 'top_k': 32,
1406+
# 'top_p': 1.0,
1407+
# 'condition_scale': 10.0,
1408+
# 'num_return_images': 2,
1409+
# 'use_faster': False,
1410+
# 'use_fp16_decoding': False,
1411+
# 'image_index_in_returned_images': 0}
1412+
#
1413+
# 多条输入, 返回值解释:[[第一个文本返回的第一张图片, 第一个文本返回的第二张图片], [第二个文本返回的第一张图片, 第二个文本返回的第二张图片]]
1414+
>>> image_list = text_to_image(["风阁水帘今在眼,且来先看早梅红", "见说春风偏有贺,露花千朵照庭闹"])
1415+
# [[<PIL.Image.Image image mode=RGB size=256x256>, <PIL.Image.Image image mode=RGB size=256x256>],
1416+
# [<PIL.Image.Image image mode=RGB size=256x256>, <PIL.Image.Image image mode=RGB size=256x256>]]
1417+
>>> for batch_index, batch_image in enumerate(image_list):
1418+
# len(batch_image) == 2 (num_return_images)
1419+
>>> for image_index_in_returned_images, each_image in enumerate(batch_image):
1420+
>>> each_image.save(f"painting-figure_{batch_index}_{image_index_in_returned_images}.png")
1421+
```
1422+
1423+
#### 支持多种模型
1424+
1425+
##### EasyNLP仓库中的pai-painter模型
1426+
```python
14051427
>>> text_to_image = Taskflow("text_to_image", model="pai-painter-commercial-base-zh")
1406-
# 多条输入
1407-
>>> images = text_to_image(["女童套头毛衣打底衫秋冬针织衫童装儿童内搭上衣", "春夏真皮工作鞋女深色软皮久站舒适上班面试职业皮鞋"])
1408-
>>> for i, image in enumerate(images):
1409-
>>> image.save(f"commercial-figure_{i}.png")
1410-
# dalle-mini模型
1428+
>>> image_list = text_to_image(["女童套头毛衣打底衫秋冬针织衫童装儿童内搭上衣", "春夏真皮工作鞋女深色软皮久站舒适上班面试职业皮鞋"])
1429+
>>> for batch_index, batch_image in enumerate(image_list):
1430+
# len(batch_image) == 2 (num_return_images)
1431+
>>> for image_index_in_returned_images, each_image in enumerate(batch_image):
1432+
>>> each_image.save(f"commercial-figure_{batch_index}_{image_index_in_returned_images}.png")
1433+
```
1434+
1435+
##### DALLE-mini模型
1436+
```python
14111437
>>> text_to_image = Taskflow("text_to_image", model="dalle-mini")
1412-
# 多条输入
1413-
>>> images = text_to_image(["New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.", "Dali painting of WALL·E"])
1414-
>>> for i, image in enumerate(images):
1415-
>>> image.save(f"dalle-mini-figure_{i}.png")
1438+
>>> image_list = text_to_image(["New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.", "Dali painting of WALL·E"])
1439+
>>> for batch_index, batch_image in enumerate(image_list):
1440+
# len(batch_image) == 2 (num_return_images)
1441+
>>> for image_index_in_returned_images, each_image in enumerate(batch_image):
1442+
>>> each_image.save(f"dalle-mini-figure_{batch_index}_{image_index_in_returned_images}.png")
1443+
```
1444+
1445+
##### Disco Diffusion模型
1446+
```python
1447+
# 注意,该模型生成速度较慢,在32G的V100上需要10分钟才能生成图片,因此默认返回1张图片。
1448+
>>> text_to_image = Taskflow("text_to_image", model="disco_diffusion_ernie_vil-2.0-base-zh")
1449+
>>> image_list = text_to_image("一幅美丽的睡莲池塘的画,由Adam Paquette在artstation上所做。")
1450+
>>> for batch_index, batch_image in enumerate(image_list):
1451+
>>> for image_index_in_returned_images, each_image in enumerate(batch_image):
1452+
>>> each_image.save(f"disco_diffusion_ernie_vil-2.0-base-zh-figure_{batch_index}_{image_index_in_returned_images}.png")
14161453
```
14171454

1455+
##### Stable Diffusion模型
1456+
```python
1457+
>>> text_to_image = Taskflow("text_to_image", model="CompVis/stable-diffusion-v1-4")
1458+
>>> prompt = [
1459+
"In the morning light,Chinese ancient buildings in the mountains,Magnificent and fantastic John Howe landscape,lake,clouds,farm,Fairy tale,light effect,Dream,Greg Rutkowski,James Gurney,artstation",
1460+
"clouds surround the mountains and Chinese palaces,sunshine,lake,overlook,overlook,unreal engine,light effect,Dream,Greg Rutkowski,James Gurney,artstation"
1461+
]
1462+
>>> image_list = text_to_image(prompt)
1463+
>>> for batch_index, batch_image in enumerate(image_list):
1464+
# len(batch_image) == 2 (num_return_images)
1465+
>>> for image_index_in_returned_images, each_image in enumerate(batch_image):
1466+
>>> each_image.save(f"stable-diffusion-figure_{batch_index}_{image_index_in_returned_images}.png")
1467+
```
1468+
1469+
#### 支持复现生成结果 (以Stable Diffusion模型为例)
1470+
```python
1471+
>>> from paddlenlp import Taskflow
1472+
>>> text_to_image = Taskflow("text_to_image", model="CompVis/stable-diffusion-v1-4")
1473+
>>> prompt = [
1474+
"In the morning light,Chinese ancient buildings in the mountains,Magnificent and fantastic John Howe landscape,lake,clouds,farm,Fairy tale,light effect,Dream,Greg Rutkowski,James Gurney,artstation",
1475+
]
1476+
>>> image_list = text_to_image(prompt)
1477+
>>> for batch_index, batch_image in enumerate(image_list):
1478+
# len(batch_image) == 2 (num_return_images)
1479+
>>> for image_index_in_returned_images, each_image in enumerate(batch_image):
1480+
>>> each_image.save(f"stable-diffusion-figure_{batch_index}_{image_index_in_returned_images}.png")
1481+
# 如果我们想复现prompt[0]文本的第二张返回的结果,我们可以首先查看生成该图像所使用的参数信息。
1482+
>>> each_image.argument
1483+
# {'mode': 'text2image',
1484+
# 'seed': 2389376819,
1485+
# 'height': 512,
1486+
# 'width': 512,
1487+
# 'num_inference_steps': 50,
1488+
# 'guidance_scale': 7.5,
1489+
# 'latents': None,
1490+
# 'num_return_images': 1,
1491+
# 'input': 'In the morning light,Chinese ancient buildings in the mountains,Magnificent and fantastic John Howe landscape,lake,clouds,farm,Fairy tale,light effect,Dream,Greg Rutkowski,James Gurney,artstation'}
1492+
# 通过set_argument设置该参数。
1493+
>>> text_to_image.set_argument(each_image.argument)
1494+
>>> new_image = text_to_image(each_image.argument["input"])
1495+
# 查看生成图片的结果,可以发现最终结果与之前的图片相一致。
1496+
>>> new_image[0][0]
1497+
```
1498+
<p align="center">
1499+
<img src="https://user-images.githubusercontent.com/50394665/188396018-284336c0-f85e-442b-a4ff-4238720de121.png" align="middle">
1500+
</p>
1501+
1502+
14181503
#### 图片生成效果展示
14191504
<p align="center">
14201505
<img src="https://user-images.githubusercontent.com/50394665/183386146-9b265304-7294-46fa-896f-1dd90f44ba31.png" align="middle">
@@ -1423,12 +1508,15 @@ from paddlenlp import Taskflow
14231508
<img src="https://user-images.githubusercontent.com/50394665/183386237-b0243ec5-09fe-47cc-9010-bd9b97fda862.png" align="middle">
14241509
<img src="https://user-images.githubusercontent.com/50394665/183387833-0f9ef786-ea62-40e1-a48c-28680d418142.png" align="middle">
14251510
<img src="https://user-images.githubusercontent.com/50394665/183387861-c4029b6c-f2e9-46d0-988f-6989f11a607d.png" align="middle">
1511+
<img src="https://user-images.githubusercontent.com/50394665/188397647-5c3e1804-82dc-4f6e-b7ec-befc15eb1910.png" align="middle" width="35%" height="35%">
1512+
<img src="https://user-images.githubusercontent.com/50394665/188397725-d43f84e7-d9aa-4fe0-a16c-2be1dc8b5c1d.png" align="middle" width="35%" height="35%">
1513+
<img src="https://user-images.githubusercontent.com/50394665/188397881-f2a76c5e-d853-4db0-be83-8ac0c2e0a634.png" align="middle" width="35%" height="35%">
1514+
<img src="https://user-images.githubusercontent.com/50394665/188397927-281402f1-a7f5-404f-9e4c-dc0236ba45ed.png" align="middle" width="35%" height="35%">
14261515
</p>
14271516

14281517
#### 可配置参数说明
1429-
* `model`:可选模型,默认为`pai-painter-painting-base-zh`,支持的模型有`["pai-painter-painting-base-zh", "pai-painter-scenery-base-zh", "pai-painter-commercial-base-zh", "dalle-mini", "dalle-mega-v16", "dalle-mega"]`
1430-
* `batch_size`:批处理大小,请结合机器情况进行调整,默认为1。
1431-
* `num_return_images`:返回图片的数量,默认为8,即8张图片水平拼接形成一张长图。
1518+
* `model`:可选模型,默认为`pai-painter-painting-base-zh`,支持的模型有`["dalle-mini", "dalle-mega", "dalle-mega-v16", "pai-painter-painting-base-zh", "pai-painter-scenery-base-zh", "pai-painter-commercial-base-zh", "CompVis/stable-diffusion-v1-4", "openai/disco-diffusion-clip-vit-base-patch32", "openai/disco-diffusion-clip-rn50", "openai/disco-diffusion-clip-rn101", "disco_diffusion_ernie_vil-2.0-base-zh"]`
1519+
* `num_return_images`:返回图片的数量,默认为2。特例:disco_diffusion模型由于生成速度太慢,因此该模型默认值为1。
14321520

14331521
</div></details>
14341522

docs/source/paddlenlp.taskflow.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ paddlenlp.taskflow
2929
paddlenlp.taskflow.sentiment_analysis
3030
paddlenlp.taskflow.task
3131
paddlenlp.taskflow.taskflow
32-
paddlenlp.taskflow.text2image_generation
32+
paddlenlp.taskflow.text_to_image
3333
paddlenlp.taskflow.text_correction
3434
paddlenlp.taskflow.text_generation
3535
paddlenlp.taskflow.text_similarity
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
text2image\_generation
1+
text\_to\_image
22
================================================
33

4-
.. automodule:: paddlenlp.taskflow.text2image_generation
4+
.. automodule:: paddlenlp.taskflow.text_to_image
55
:members:
66
:no-undoc-members:
77
:show-inheritance:

paddlenlp/taskflow/taskflow.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .dialogue import DialogueTask
3838
from .information_extraction import UIETask, GPTask
3939
from .code_generation import CodeGenerationTask
40-
from .text_to_image import TextToImageGenerationTask
40+
from .text_to_image import TextToImageGenerationTask, TextToImageDiscoDiffusionTask, TextToImageStableDiffusionTask
4141
from .text_summarization import TextSummarizationTask
4242

4343
warnings.simplefilter(action='ignore', category=Warning, lineno=0, append=False)
@@ -404,6 +404,35 @@
404404
"task_flag": "text_to_image-pai-painter-commercial-base-zh",
405405
"task_priority_path": "pai-painter-commercial-base-zh",
406406
},
407+
"openai/disco-diffusion-clip-vit-base-patch32": {
408+
"task_class":
409+
TextToImageDiscoDiffusionTask,
410+
"task_flag":
411+
"text_to_image-openai/disco-diffusion-clip-vit-base-patch32",
412+
"task_priority_path":
413+
"openai/disco-diffusion-clip-vit-base-patch32",
414+
},
415+
"openai/disco-diffusion-clip-rn50": {
416+
"task_class": TextToImageDiscoDiffusionTask,
417+
"task_flag": "text_to_image-openai/disco-diffusion-clip-rn50",
418+
"task_priority_path": "openai/disco-diffusion-clip-rn50",
419+
},
420+
"openai/disco-diffusion-clip-rn101": {
421+
"task_class": TextToImageDiscoDiffusionTask,
422+
"task_flag": "text_to_image-openai/disco-diffusion-clip-rn101",
423+
"task_priority_path": "openai/disco-diffusion-clip-rn101",
424+
},
425+
"disco_diffusion_ernie_vil-2.0-base-zh": {
426+
"task_class": TextToImageDiscoDiffusionTask,
427+
"task_flag":
428+
"text_to_image-disco_diffusion_ernie_vil-2.0-base-zh",
429+
"task_priority_path": "disco_diffusion_ernie_vil-2.0-base-zh",
430+
},
431+
"CompVis/stable-diffusion-v1-4": {
432+
"task_class": TextToImageStableDiffusionTask,
433+
"task_flag": "text_to_image-CompVis/stable-diffusion-v1-4",
434+
"task_priority_path": "CompVis/stable-diffusion-v1-4",
435+
},
407436
},
408437
"default": {
409438
"model": "pai-painter-painting-base-zh",
@@ -416,6 +445,15 @@
416445
"uie-medical-base", "uie-base-en", "wordtag", "uie-m-large", "uie-m-base"
417446
]
418447

448+
support_argument_list = [
449+
"dalle-mini", "dalle-mega", "dalle-mega-v16",
450+
"pai-painter-painting-base-zh", "pai-painter-scenery-base-zh",
451+
"pai-painter-commercial-base-zh", "CompVis/stable-diffusion-v1-4",
452+
"openai/disco-diffusion-clip-vit-base-patch32",
453+
"openai/disco-diffusion-clip-rn50", "openai/disco-diffusion-clip-rn101",
454+
"disco_diffusion_ernie_vil-2.0-base-zh"
455+
]
456+
419457

420458
class Taskflow(object):
421459
"""
@@ -521,4 +559,8 @@ def interactive_mode(self, max_turn):
521559

522560
def set_schema(self, schema):
523561
assert self.task_instance.model in support_schema_list, 'This method can only be used by the task with the model of uie or wordtag.'
524-
self.task_instance.set_schema(schema)
562+
self.task_instance.set_schema(schema)
563+
564+
def set_argument(self, argument):
565+
assert self.task_instance.model in support_argument_list, 'This method can only be used by the task with the model of text_to_image generation.'
566+
self.task_instance.set_argument(argument)

0 commit comments

Comments
 (0)