File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change @@ -797,6 +797,10 @@ def _gen_model_input(
797797 max_new_tokens is not None
798798 ), "max_new_tokens must be specified for Flamingo models"
799799
800+ # Wrap string prompts into a list
801+ if isinstance (prompt , str ):
802+ prompt = [{"role" : "user" , "content" : prompt }]
803+
800804 image_found = False
801805 messages = []
802806 for message in prompt :
@@ -959,8 +963,9 @@ def chat(
959963 max_seq_length = (
960964 text_transformer_args .max_seq_length if text_transformer_args else 2048
961965 )
966+
962967 encoded , batch = self ._gen_model_input (
963- [{ "role" : "user" , "content" : generator_args .prompt }] ,
968+ generator_args .prompt ,
964969 generator_args .image_prompts ,
965970 generator_args .max_new_tokens ,
966971 max_seq_length ,
You can’t perform that action at this time.
0 commit comments