软件环境
重复问题
错误描述
稳定复现步骤 & 代码
generation_utils.py#865L
现有的逻辑中,对于input_ids与inputs_embeds的适配存在潜在bug。并且prepare_input_ids_for_generation方法入参太少,难以适配。
比如我做encoder_decoder任务,此时同时加上repeation惩罚,此时需要利用到来自encoder的input_ids来计算惩罚,此时我会在generate方法中传入input_ids和input_embeds,但是这个if/elif会强制把input_ids重置为[bos_token_id],是不是这里需要elif input_ids is None and inputs_embeds in model_kwargs呢?
而且,prepare_input_ids_for_generation 基本上没传入什么参数,只有一个bos_token_id,和model_kwargs,没有传入更多的参数,比如generate方法中的入参,导致这个方法根本没有重写的意义,扩展性很差。
比如,在上述需要用到input_ids的场景中,如果该方法也传入了input_ids,那么我将可以直接重写覆盖这个方法,修改构造input_ids的逻辑即可,但是现在的框架逻辑不支持这种扩展,只能子类重写generate,很不友好。建议借鉴huggingface的实现!!!