[Hackathon 4th No.104] Align the generate API with HF #5587
Closed
Commits (13):
- e712eee: style error (zwhong714)
- ab74e34: fix bug (zwhong714)
- a2ea337: resort and add test but may exist bug (zwhong714)
- 918e0b5: fix bugs(pass all test), update RFC, and will add generation-utils te… (zwhong714)
- 9d99f90: align with transformers, NEED UNITEST on encoder-decoder arch, may ex… (zwhong714)
- 1d3f7c8: final-version (zwhong714)
- ea134e3: delete some comments (zwhong714)
- fc071b1: merge done! (zwhong714)
- 8bb863f: add slow (zwhong714)
- 42da2c9: remove test on generation_utils (zwhong714)
- 7c48a4a: not done (zwhong714)
- 0aea313: 5/12 (zwhong714)
- bf36005: merge (zwhong714)
@@ -0,0 +1,180 @@
# paddlenlp.transformers.generation_utils.py

| API name | New API name |
|---|---|
| Author | Wenhong Zhu |
| Submission date | 2023-03-17 |
| Version | V1.0 |
| Required PaddlePaddle version | PaddleNLP v2.5 |
| File name | 20230317_api_design_for_generation.md |

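# I. Overview
## 1. Background

Decoding algorithms play an important role in text generation: a good decoding algorithm reduces repetition and language degeneration in the generated text. In recent years, the shortcomings of common decoding algorithms such as greedy search, beam search, top-k, and top-p have become increasingly apparent in generation tasks, and better decoding algorithms keep being proposed. One example is contrastive search, added in this proposal, which keeps the generated text semantically coherent while making it more diverse. As the HuggingFace open-source community continues to grow, some interfaces need to be aligned with HuggingFace in order to extend the reach of the PaddlePaddle framework.

## 2. Functional goals

Align the PaddleNLP [generation API](https://github.com/PaddlePaddle/PaddleNLP/blob/v2.5.0/paddlenlp/transformers/generation_utils.py) with HuggingFace Transformers: refactor [sample](https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/generation/utils.py#L2259) and add [contrastive_search](https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/generation/utils.py#L1659).

## 3. Significance

Broaden the range of available decoding algorithms, align with HuggingFace's open-source components, and increase the influence of PaddleNLP.
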
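# II. Current state of PaddlePaddle

NLP is developing rapidly and many new algorithms are being proposed, including emerging ones such as typical decoding and eta-decoding. These are hard to integrate into the current PaddleNLP generation API: doing so usually means adding and rewriting a large amount of code. In addition, some functions have bloated parameter lists that could be consolidated. Contrastive search was proposed only recently, and no existing API can substitute for it, so it has to be added.
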
# III. Survey of industry solutions

HuggingFace is currently the largest open-source community for natural language processing models, and some of its code design is worth borrowing.

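# IV. Comparative analysis

The generation API design in HuggingFace has the following advantages:
- All functions that adjust the predicted probability distribution are unified under one class, which makes them easier to understand, standardize, and extend with new components.
- The conditions for stopping the generation loop are likewise unified under one class, which makes them easier to understand, standardize, and extend with new components.
- Parameters shared by several decoding algorithms are handled once in generation rather than in each individual decoding algorithm.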
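
The first point can be illustrated with a minimal sketch of the "processor list" pattern. It is not the HuggingFace or PaddleNLP implementation: the class names `TemperatureWarper`, `TopKWarper`, and `ProcessorList` are hypothetical, and the scores are plain Python lists rather than `paddle.Tensor` objects so that the example runs without any framework.

```python
import math
from typing import List

# Hypothetical, simplified stand-ins for the processor classes.
# Scores are plain lists of logits, one entry per vocabulary token.
Scores = List[float]


class TemperatureWarper:
    """Divide every logit by a positive temperature."""

    def __init__(self, temperature: float):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature

    def __call__(self, input_ids: List[int], scores: Scores) -> Scores:
        return [s / self.temperature for s in scores]


class TopKWarper:
    """Keep the k largest logits and mask the rest with -inf.

    Ties at the threshold are all kept, so more than k tokens may survive.
    """

    def __init__(self, top_k: int):
        if top_k < 1:
            raise ValueError("top_k must be at least 1")
        self.top_k = top_k

    def __call__(self, input_ids: List[int], scores: Scores) -> Scores:
        k = min(self.top_k, len(scores))
        threshold = sorted(scores, reverse=True)[k - 1]
        return [s if s >= threshold else -math.inf for s in scores]


class ProcessorList(list):
    """Apply every processor in order; each one sees the previous output."""

    def __call__(self, input_ids: List[int], scores: Scores) -> Scores:
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


processors = ProcessorList([TemperatureWarper(2.0), TopKWarper(2)])
adjusted = processors([0], [4.0, 1.0, 2.0, 0.0])
print(adjusted)
```

With this shape, a new adjustment such as typical decoding becomes one more class appended to the list, and the decoding loop itself does not change.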

# V. Design and implementation plan

- Refactor the API following the HuggingFace design.
- Add contrastive search.

## Naming and parameter design

Consistent with PaddleNLP.

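## API implementation plan

Decoding objective:
$$
x_t=\underset{v \in V^{(k)}}{\arg \max }\{(1-\alpha) \times \underbrace{p_\theta\left(v \mid \boldsymbol{x}_{<t}\right)}_{\text {model confidence }}-\alpha \times \underbrace{\left(\max \left\{s\left(h_v, h_{x_j}\right): 1 \leq j \leq t-1\right\}\right)}_{\text {degeneration penalty }}\}
$$

Here $V^{(k)}$ is the set of the $k$ most probable candidate tokens, $h_v$ and $h_{x_j}$ are the token representations of a candidate and of a previously generated token, $s$ is cosine similarity, and $\alpha$ corresponds to the `penalty_alpha` parameter.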
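
As a rough illustration of the objective above, the sketch below scores the candidates for a single decoding step. It assumes the top-k probabilities and the hidden states are already available as plain Python lists with non-zero vectors; the helper names `cosine` and `contrastive_step` are hypothetical and are not part of PaddleNLP or Transformers.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity s(u, v); assumes neither vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def contrastive_step(
    probs: List[float],
    candidate_hidden: List[List[float]],
    context_hidden: List[List[float]],
    alpha: float,
) -> int:
    """Return the index of the candidate that maximizes
    (1 - alpha) * p(v | x_<t) - alpha * max_j s(h_v, h_{x_j})."""
    best_index, best_score = -1, -math.inf
    for i, (p, h_v) in enumerate(zip(probs, candidate_hidden)):
        # Degeneration penalty: similarity to the closest previous token.
        penalty = max(cosine(h_v, h_x) for h_x in context_hidden)
        score = (1 - alpha) * p - alpha * penalty
        if score > best_score:
            best_index, best_score = i, score
    return best_index


# Two candidates; the first is more probable but identical to the context.
probs = [0.6, 0.4]
candidate_hidden = [[1.0, 0.0], [0.0, 1.0]]
context_hidden = [[1.0, 0.0]]

greedy_choice = contrastive_step(probs, candidate_hidden, context_hidden, alpha=0.0)
contrastive_choice = contrastive_step(probs, candidate_hidden, context_hidden, alpha=0.7)
print(greedy_choice, contrastive_choice)
```

With `alpha=0.0` the rule reduces to picking the most probable candidate; with `alpha=0.7` the first candidate is penalized for repeating the context and the second one wins.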

- Add a logits_process.py file: holds the functions that adjust the predicted token probability distribution.
- Add a stopping_criteria.py file: holds the functions that check the stopping conditions.
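
What a stopping-criteria module might contain can be sketched as follows. This is a simplified illustration under assumptions, not the contents of the actual file: `MaxLengthCriteria`, `EosCriteria`, `CriteriaList`, and `toy_generate` are hypothetical names, and a sequence is a plain list of token ids.

```python
from typing import Callable, List


class MaxLengthCriteria:
    """Stop once the sequence has reached max_length tokens."""

    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: List[int]) -> bool:
        return len(input_ids) >= self.max_length


class EosCriteria:
    """Stop once the last token is the end-of-sequence token."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: List[int]) -> bool:
        return bool(input_ids) and input_ids[-1] == self.eos_token_id


class CriteriaList(list):
    """Stop as soon as any criterion in the list fires."""

    def __call__(self, input_ids: List[int]) -> bool:
        return any(criterion(input_ids) for criterion in self)


def toy_generate(
    prompt: List[int],
    next_token: Callable[[List[int]], int],
    stopping: CriteriaList,
) -> List[int]:
    """Toy decoding loop: the stopping logic lives entirely in `stopping`."""
    ids = list(prompt)
    while not stopping(ids):
        ids.append(next_token(ids))
    return ids


stopping = CriteriaList([MaxLengthCriteria(6), EosCriteria(eos_token_id=9)])
# The toy "model" always emits the previous token id plus one.
stopped_by_length = toy_generate([1, 2], lambda ids: ids[-1] + 1, stopping)
stopped_by_eos = toy_generate([7], lambda ids: ids[-1] + 1, stopping)
print(stopped_by_length, stopped_by_eos)
```

Because the loop only asks `stopping(ids)`, adding a new condition (for example a time limit) would not require touching any decoding algorithm.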

```python
from typing import Optional

import paddle


# Both functions are methods of the generation class, hence the `self` argument.
# LogitsProcessorList is expected to live in the new logits_process.py; it is
# quoted here so the signatures can be read without importing it.
def sample(
    self,
    input_ids: paddle.Tensor,
    logits_processors: Optional["LogitsProcessorList"] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    top_k=None,
    top_p=None,
    temperature=None,
    min_tokens_to_keep=1,
    **model_kwargs
):
    pass


def contrastive_search(
    self,
    input_ids,
    logits_processors,
    max_length,
    pad_token_id,
    eos_token_id,
    penalty_alpha: Optional[float] = 0,
    top_k=None,
    temperature=None,
    **model_kwargs
):
    pass
```
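
To make the `top_p` and `min_tokens_to_keep` parameters of `sample` more concrete, here is a minimal sketch of nucleus (top-p) filtering on plain Python lists. It is an illustration under assumptions rather than the PaddleNLP or Transformers code: `softmax` and `top_p_filter` are hypothetical helpers, and behavior exactly at the cumulative-probability boundary may differ from the real implementations.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(
    logits: List[float], top_p: float, min_tokens_to_keep: int = 1
) -> List[float]:
    """Keep the most probable tokens until their cumulative probability
    reaches top_p (always at least min_tokens_to_keep of them) and mask
    every other token with -inf."""
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep = set()
    cumulative = 0.0
    for rank, index in enumerate(order):
        if rank >= min_tokens_to_keep and cumulative >= top_p:
            break
        keep.add(index)
        cumulative += probs[index]
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


# Logits chosen so that the probabilities are 0.5, 0.3, and 0.2.
logits = [math.log(0.5), math.log(0.3), math.log(0.2)]
nucleus = top_p_filter(logits, top_p=0.7)
narrow = top_p_filter(logits, top_p=0.4)
narrow_keep_two = top_p_filter(logits, top_p=0.4, min_tokens_to_keep=2)
print(nucleus, narrow, narrow_keep_two)
```

In a design that follows the processor pattern, a filter like this would be one entry in the `logits_processors` list instead of inline logic inside `sample`.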

# VI. Testing and acceptance considerations
```python
import paddle
from paddlenlp.transformers import (
    UnifiedTransformerLMHeadModel,
    UnifiedTransformerTokenizer,
)

paddle.seed(2)

# Initialize the model and tokenizer.
model_name_or_path = "unified_transformer-12L-cn-luge"
model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)

# Prepare the model inputs.
# The prompt means "Good morning, the air quality is nice today."
history = "早上好,今天空气质量不错。"
inputs = tokenizer.dialogue_encode(
    history,
    task_type="chitchat",
    add_start_token_as_response=True,
    return_tensors=True,
)


def decode(ids):
    """Convert a batch of generated token ids into a list of strings."""
    response = []
    for sequence_ids in ids.numpy().tolist():
        # Truncate at the first sep token only if one was generated;
        # a sequence may contain no sep_token_id at all.
        if tokenizer.sep_token_id in sequence_ids:
            sequence_ids = sequence_ids[: sequence_ids.index(tokenizer.sep_token_id)]
        text = tokenizer.convert_ids_to_string(sequence_ids, keep_space=False)
        response.append(text)
    return response


# Case 1: the refactored sampling strategy.
result = model.generate(
    input_ids=inputs["input_ids"],
    token_type_ids=inputs["token_type_ids"],
    position_ids=inputs["position_ids"],
    attention_mask=inputs["attention_mask"],
    decode_strategy="sampling",
    penalty_alpha=0.7,
    top_k=6,
)
print(decode(result[0]))

# Case 2: the new contrastive search strategy.
result = model.generate(
    input_ids=inputs["input_ids"],
    token_type_ids=inputs["token_type_ids"],
    position_ids=inputs["position_ids"],
    attention_mask=inputs["attention_mask"],
    decode_strategy="contrastive_search",
    penalty_alpha=0.7,
    top_k=6,
)
print(decode(result[0]))
```

# VII. Feasibility analysis and schedule

2/26: complete the survey of sample

3/1: complete the refactoring of sample

3/1: complete the survey of contrastive search

4/15: complete the whole project

# VIII. Impact

No impact on other modules. The expectation is that decoding algorithms added later can be integrated into the PaddleNLP generation API faster and more cleanly.

# Glossary
None

# Appendix and references
Reference paper: https://arxiv.org/abs/2210.14140

Reference code: https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/generation/utils.py

@@ -38,6 +38,7 @@

    datasets,
    embeddings,
    experimental,
    generation,
    layers,
    losses,
    metrics,

@@ -0,0 +1,13 @@

    # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    # http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.