Skip to content

Commit 3b9ac65

Browse files
authored
🖨️ Print rich table for messages (#4160)
1 parent 7a78320 commit 3b9ac65

File tree

2 files changed

+122
-7
lines changed

2 files changed

+122
-7
lines changed

tests/test_utils.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,94 @@ def test_num_samples(self, mock_stdout):
743743
]
744744
self.assertIn(output, possible_outputs)
745745

746+
@patch("sys.stdout", new_callable=StringIO)
747+
def test_print_messages(self, mock_stdout):
# Checks the table printed by print_prompt_completions_sample when prompts and completions are
# lists of chat messages (dicts with "role" and "content"). stdout is replaced by a StringIO so
# the rendered output can be compared as a string.
748+
prompts = [
749+
[
750+
{"role": "system", "content": "You are an helpful assistant."},
751+
{"role": "user", "content": "What color is the sky?"},
752+
],
753+
[
754+
{"role": "system", "content": "You are an helpful assistant."},
755+
{"role": "user", "content": "Where is the sun?"},
756+
],
757+
]
758+
completions = [
759+
[{"role": "assistant", "content": "It is blue."}],
760+
[{"role": "assistant", "content": "In the sky."}],
761+
]
762+
rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
763+
advantages = [0.987, 0.654]
764+
step = 42
765+
766+
print_prompt_completions_sample(prompts, completions, rewards, advantages, step)
767+
768+
output = mock_stdout.getvalue()
# The comparison below is an exact string match: each message is expected to show its role in
# upper case above its content, and rewards/advantages are expected with two decimals
# (0.123 -> 0.12, 0.987 -> 0.99).
# NOTE(review): in this view the header/data rows are narrower than the border rows, so the
# column padding of the literal looks whitespace-collapsed; confirm against the repository copy.
769+
770+
# docstyle-ignore
771+
expected_output = textwrap.dedent("""\
772+
╭────────────────────────────────── Step 42 ───────────────────────────────────╮
773+
│ ┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
774+
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
775+
│ ┡━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
776+
│ │ SYSTEM │ ASSISTANT │ 0.12 │ 0.79 │ 0.99 │ │
777+
│ │ You are an helpful │ It is blue. │ │ │ │ │
778+
│ │ assistant. │ │ │ │ │ │
779+
│ │ │ │ │ │ │ │
780+
│ │ USER │ │ │ │ │ │
781+
│ │ What color is the sky? │ │ │ │ │ │
782+
│ ├─────────────────────────┼─────────────┼─────────────┼────────┼───────────┤ │
783+
│ │ SYSTEM │ ASSISTANT │ 0.46 │ 0.10 │ 0.65 │ │
784+
│ │ You are an helpful │ In the sky. │ │ │ │ │
785+
│ │ assistant. │ │ │ │ │ │
786+
│ │ │ │ │ │ │ │
787+
│ │ USER │ │ │ │ │ │
788+
│ │ Where is the sun? │ │ │ │ │ │
789+
│ └─────────────────────────┴─────────────┴─────────────┴────────┴───────────┘ │
790+
╰──────────────────────────────────────────────────────────────────────────────╯
791+
""")
792+
793+
self.assertEqual(output, expected_output)
794+
795+
@patch("sys.stdout", new_callable=StringIO)
796+
def test_print_messages_with_tools(self, mock_stdout):
# Checks the table printed by print_prompt_completions_sample when the completions are tool
# calls: dicts carrying "role", "name" and "args" but no "content" key. stdout is replaced by a
# StringIO so the rendered output can be compared as a string.
797+
prompts = [
798+
[{"role": "user", "content": "What is the temperature in Paris?"}],
799+
[{"role": "user", "content": "What is the weather in London?"}],
800+
]
801+
completions = [
802+
[{"role": "tool", "name": "get_temperature", "args": {"location": "Paris"}}],
803+
[{"role": "tool", "name": "get_weather", "args": {"location": "London"}}],
804+
]
805+
rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
806+
advantages = [0.987, 0.654]
807+
step = 42
808+
809+
print_prompt_completions_sample(prompts, completions, rewards, advantages, step)
810+
811+
output = mock_stdout.getvalue()
# The comparison below is an exact string match: the tool call is expected under an upper-cased
# "TOOL" header as name(args) text, which the expected table shows cut with an ellipsis and
# continued on the next row of the Completion column.
# NOTE(review): in this view the header/data rows are narrower than the border rows, so the
# column padding of the literal looks whitespace-collapsed; confirm against the repository copy.
812+
813+
# docstyle-ignore
814+
expected_output = textwrap.dedent("""\
815+
╭────────────────────────────────── Step 42 ───────────────────────────────────╮
816+
│ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
817+
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
818+
│ ┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
819+
│ │ USER │ TOOL │ 0.12 │ 0.79 │ 0.99 │ │
820+
│ │ What is the │ get_temperature(… │ │ │ │ │
821+
│ │ temperature in │ 'Paris'}) │ │ │ │ │
822+
│ │ Paris? │ │ │ │ │ │
823+
│ ├───────────────────┼───────────────────┼─────────────┼────────┼───────────┤ │
824+
│ │ USER │ TOOL │ 0.46 │ 0.10 │ 0.65 │ │
825+
│ │ What is the │ get_weather({'lo… │ │ │ │ │
826+
│ │ weather in │ 'London'}) │ │ │ │ │
827+
│ │ London? │ │ │ │ │ │
828+
│ └───────────────────┴───────────────────┴─────────────┴────────┴───────────┘ │
829+
╰──────────────────────────────────────────────────────────────────────────────╯
830+
""")
831+
832+
self.assertEqual(output, expected_output)
833+
746834

747835
class TestSelectiveLogSoftmax(TrlTestCase):
748836
@parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])

trl/trainer/utils.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,8 +1528,8 @@ def entropy_from_logits(logits: torch.Tensor, chunk_size: int = 128) -> torch.Te
15281528

15291529

15301530
def print_prompt_completions_sample(
1531-
prompts: list[str],
1532-
completions: list[str],
1531+
prompts: list,
1532+
completions: list,
15331533
rewards: dict[str, list[float]],
15341534
advantages: list[float],
15351535
step: int,
@@ -1542,10 +1542,10 @@ def print_prompt_completions_sample(
15421542
during training. It requires the `rich` library to be installed.
15431543
15441544
Args:
1545-
prompts (`list[str]`):
1546-
List of prompts.
1547-
completions (`list[str]`):
1548-
List of completions corresponding to the prompts.
1545+
prompts (`list`):
1546+
List of prompts. Can be either strings or lists of messages.
1547+
completions (`list`):
1548+
List of completions corresponding to the prompts. Can be either strings or lists of messages.
15491549
rewards (`dict[str, list[float]]`):
15501550
Dictionary where keys are reward names and values are lists of rewards.
15511551
advantages (`list[float]`):
@@ -1590,6 +1590,28 @@ def print_prompt_completions_sample(
15901590
table.add_column(reward_name, style="bold cyan", justify="right")
15911591
table.add_column("Advantage", style="bold magenta", justify="right")
15921592

1593+
def format_entry(entry) -> Text:
    """Render one prompt or completion as a rich ``Text`` for a table cell.

    Args:
        entry: Either a list of message dicts (conversational format) or any
            other object, which is rendered via ``str(entry)``.

    Returns:
        A ``Text`` in which each message contributes an upper-cased role
        header (bold red) followed by its content (chat message) or
        ``name(args)`` (tool call). Messages are separated by a blank line.
        A message dict with neither shape is rendered via ``str(msg)``
        without a role header.
    """
    t = Text()

    # Anything that is not a list of dicts is treated as a plain entry.
    if not (isinstance(entry, list) and all(isinstance(m, dict) for m in entry)):
        t.append(str(entry))
        return t

    for j, msg in enumerate(entry):
        if j > 0:
            # Blank line between consecutive messages (none after the last one)
            t.append("\n\n")

        # "role" may be missing or explicitly None; both yield an empty header
        # instead of failing on `.upper()`.
        role = str(msg.get("role") or "").upper()

        if "content" in msg:
            # Chat message
            content = msg["content"]
            t.append(f"{role}\n", style="bold red")
            # Text.append only accepts str or Text; stringify anything else
            # (e.g. None or a list of content parts) rather than raising.
            t.append(content if isinstance(content, (str, Text)) else str(content))
        elif "name" in msg and "args" in msg:
            # Tool call
            t.append(f"{role}\n", style="bold red")
            t.append(f"{msg['name']}({msg['args']})")
        else:
            # Fallback: unknown message shape, show the raw dict
            t.append(str(msg))
    return t
1614+
15931615
# Some basic input validation
15941616
if num_samples is not None:
15951617
if num_samples >= len(prompts):
@@ -1607,7 +1629,12 @@ def print_prompt_completions_sample(
16071629

16081630
for i in range(len(prompts)):
16091631
reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals
1610-
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values, f"{advantages[i]:.2f}")
1632+
table.add_row(
1633+
format_entry(prompts[i]),
1634+
format_entry(completions[i]),
1635+
*reward_values,
1636+
f"{advantages[i]:.2f}",
1637+
)
16111638
table.add_section() # Adds a separator between rows
16121639

16131640
panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white")

0 commit comments

Comments
 (0)