@@ -160,29 +160,49 @@ def test_wrap_conversation_pipeline():
         framework="pt",
     )
     conv_pipe = wrap_conversation_pipeline(init_pipeline)
-    data = {
-        "past_user_inputs": ["Which movie is the best ?"],
-        "generated_responses": ["It's Die Hard for sure."],
-        "text": "Can you explain why?",
-    }
+    data = [
+        {
+            "role": "user",
+            "content": "Which movie is the best ?"
+        },
+        {
+            "role": "assistant",
+            "content": "It's Die Hard for sure."
+        },
+        {
+            "role": "user",
+            "content": "Can you explain why?"
+        }
+    ]
     res = conv_pipe(data)
-    assert "conversation" in res
-    assert "generated_text" in res
+    assert "content" in res.messages[-1]


 @require_torch
 def test_wrapped_pipeline():
     with tempfile.TemporaryDirectory() as tmpdirname:
-        storage_dir = _load_repository_from_hf("hf-internal-testing/tiny-random-blenderbot", tmpdirname, framework="pytorch")
+        storage_dir = _load_repository_from_hf(
+            repository_id="microsoft/DialoGPT-small",
+            target_dir=tmpdirname,
+            framework="pytorch"
+        )
         conv_pipe = get_pipeline("conversational", storage_dir.as_posix())
-        data = {
-            "past_user_inputs": ["Which movie is the best ?"],
-            "generated_responses": ["It's Die Hard for sure."],
-            "text": "Can you explain why?",
-        }
+        data = [
+            {
+                "role": "user",
+                "content": "Which movie is the best ?"
+            },
+            {
+                "role": "assistant",
+                "content": "It's Die Hard for sure."
+            },
+            {
+                "role": "user",
+                "content": "Can you explain why?"
+            }
+        ]
         res = conv_pipe(data)
-        assert "conversation" in res
-        assert "generated_text" in res
+        assert "content" in res.messages[-1]


 def test_local_custom_pipeline():
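For context, here is a minimal sketch of the chat-message contract these tests now exercise. It assumes a transformers release (roughly 4.34 or later) in which the "conversational" pipeline accepts a list of {"role", "content"} dicts directly and returns a Conversation object that exposes the full history via `.messages`; the model name is simply the one used in the test above.

```python
# Minimal sketch of the chat-message format exercised by the updated tests.
# Assumption: a transformers version whose "conversational" pipeline accepts
# a list of {"role", "content"} dicts and returns a Conversation object
# whose `.messages` list holds the full dialogue history.
from transformers import pipeline

chat_pipe = pipeline("conversational", model="microsoft/DialoGPT-small")

messages = [
    {"role": "user", "content": "Which movie is the best ?"},
    {"role": "assistant", "content": "It's Die Hard for sure."},
    {"role": "user", "content": "Can you explain why?"},
]

result = chat_pipe(messages)

# The pipeline appends the generated reply to the history, so the last
# message holds the model's answer, which is what the tests assert on.
print(result.messages[-1]["role"])     # "assistant"
print(result.messages[-1]["content"])  # generated reply text
```

This mirrors the migration in the diff: the legacy past_user_inputs/generated_responses/text dict is replaced by an OpenAI-style messages list, and the result is inspected through `res.messages[-1]` rather than top-level "conversation"/"generated_text" keys.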