@@ -743,6 +743,94 @@ def test_num_samples(self, mock_stdout):
743743 ]
744744 self .assertIn (output , possible_outputs )
745745
@patch("sys.stdout", new_callable=StringIO)
def test_print_messages(self, mock_stdout):
    """Check the table printed for conversational prompts and completions.

    Two samples are passed in, each with a system + user prompt and a single
    assistant completion. The captured stdout must equal a "Step 42" panel
    holding one table row per sample: role names in upper case above each
    message, and reward / advantage values shown with two decimals.

    Args:
        mock_stdout: ``StringIO`` substituted for ``sys.stdout`` by ``patch``;
            it collects everything the function under test prints.
    """
    prompts = [
        [
            {"role": "system", "content": "You are an helpful assistant."},
            {"role": "user", "content": "What color is the sky?"},
        ],
        [
            {"role": "system", "content": "You are an helpful assistant."},
            {"role": "user", "content": "Where is the sun?"},
        ],
    ]
    completions = [
        [{"role": "assistant", "content": "It is blue."}],
        [{"role": "assistant", "content": "In the sky."}],
    ]
    # One list entry per sample, keyed by the reward name used as column header.
    rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
    advantages = [0.987, 0.654]
    step = 42

    print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

    output = mock_stdout.getvalue()

    # Every row must be exactly as wide as the border rows (80 columns), so
    # each cell is padded out to the column width given by the ━ runs.
    # NOTE(review): the padding was rebuilt from the border widths; text
    # columns are left-aligned and numeric columns are assumed right-aligned
    # -- confirm against the real rendered output.
    # docstyle-ignore
    expected_output = textwrap.dedent("""\
    ╭────────────────────────────────── Step 42 ───────────────────────────────────╮
    │ ┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
    │ ┃ Prompt                  ┃ Completion  ┃ Correctness ┃ Format ┃ Advantage ┃ │
    │ ┡━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
    │ │ SYSTEM                  │ ASSISTANT   │        0.12 │   0.79 │      0.99 │ │
    │ │ You are an helpful      │ It is blue. │             │        │           │ │
    │ │ assistant.              │             │             │        │           │ │
    │ │                         │             │             │        │           │ │
    │ │ USER                    │             │             │        │           │ │
    │ │ What color is the sky?  │             │             │        │           │ │
    │ ├─────────────────────────┼─────────────┼─────────────┼────────┼───────────┤ │
    │ │ SYSTEM                  │ ASSISTANT   │        0.46 │   0.10 │      0.65 │ │
    │ │ You are an helpful      │ In the sky. │             │        │           │ │
    │ │ assistant.              │             │             │        │           │ │
    │ │                         │             │             │        │           │ │
    │ │ USER                    │             │             │        │           │ │
    │ │ Where is the sun?       │             │             │        │           │ │
    │ └─────────────────────────┴─────────────┴─────────────┴────────┴───────────┘ │
    ╰──────────────────────────────────────────────────────────────────────────────╯
    """)

    self.assertEqual(output, expected_output)
794+
@patch("sys.stdout", new_callable=StringIO)
def test_print_messages_with_tools(self, mock_stdout):
    """Check the table printed when completions are tool-call messages.

    Each completion is a single ``"tool"`` message carrying a ``name`` and an
    ``args`` dict instead of ``content``. The expected table shows it under a
    ``TOOL`` heading as ``name(args)``, cut with an ellipsis and continued on
    the next line where it does not fit the 17-character cell.

    Args:
        mock_stdout: ``StringIO`` substituted for ``sys.stdout`` by ``patch``;
            it collects everything the function under test prints.
    """
    prompts = [
        [{"role": "user", "content": "What is the temperature in Paris?"}],
        [{"role": "user", "content": "What is the weather in London?"}],
    ]
    completions = [
        [{"role": "tool", "name": "get_temperature", "args": {"location": "Paris"}}],
        [{"role": "tool", "name": "get_weather", "args": {"location": "London"}}],
    ]
    # One list entry per sample, keyed by the reward name used as column header.
    rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
    advantages = [0.987, 0.654]
    step = 42

    print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

    output = mock_stdout.getvalue()

    # Every row must be exactly as wide as the border rows (80 columns), so
    # each cell is padded out to the column width given by the ━ runs.
    # NOTE(review): the padding was rebuilt from the border widths; text
    # columns are left-aligned and numeric columns are assumed right-aligned
    # -- confirm against the real rendered output.
    # docstyle-ignore
    expected_output = textwrap.dedent("""\
    ╭────────────────────────────────── Step 42 ───────────────────────────────────╮
    │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
    │ ┃ Prompt            ┃ Completion        ┃ Correctness ┃ Format ┃ Advantage ┃ │
    │ ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
    │ │ USER              │ TOOL              │        0.12 │   0.79 │      0.99 │ │
    │ │ What is the       │ get_temperature(… │             │        │           │ │
    │ │ temperature in    │ 'Paris'})         │             │        │           │ │
    │ │ Paris?            │                   │             │        │           │ │
    │ ├───────────────────┼───────────────────┼─────────────┼────────┼───────────┤ │
    │ │ USER              │ TOOL              │        0.46 │   0.10 │      0.65 │ │
    │ │ What is the       │ get_weather({'lo… │             │        │           │ │
    │ │ weather in        │ 'London'})        │             │        │           │ │
    │ │ London?           │                   │             │        │           │ │
    │ └───────────────────┴───────────────────┴─────────────┴────────┴───────────┘ │
    ╰──────────────────────────────────────────────────────────────────────────────╯
    """)

    self.assertEqual(output, expected_output)
833+
746834
747835class TestSelectiveLogSoftmax (TrlTestCase ):
748836 @parameterized .expand ([(torch .float64 ,), (torch .float32 ,), (torch .float16 ,), (torch .bfloat16 ,)])
0 commit comments