Skip to content

Commit b96cfc7

Browse files
committed
fix
Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 930ff73 commit b96cfc7

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

tests/functional_tests/hf_transformer/test_formatting_utils_options.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,14 @@ def test_format_prompt_completion_options(seq_length, padding, truncation):
9696

9797
# Attention mask should have zeros only in padded tail (if any)
9898
if isinstance(seq_length, int):
99-
# From the end, once we see a 0, the rest must be 0
100-
seen_zero = False
99+
# From the end, once we see a non-zero, no zeros should appear (right padding)
100+
seen_nonzero = False
101101
for v in reversed(out["attention_mask"]):
102-
if v == 0:
103-
seen_zero = True
102+
if v != 0:
103+
seen_nonzero = True
104104
else:
105-
if seen_zero:
106-
pytest.fail("Non-zero attention_mask value after padded zeros.")
105+
if seen_nonzero:
106+
pytest.fail("Zero attention_mask value before non-padded tokens (padding not only in tail). ")
107107

108108

109109
@pytest.mark.parametrize(
@@ -170,12 +170,13 @@ def test_format_chat_template_options(seq_length, padding, truncation):
170170

171171
# Attention mask padded tail zeros, if padded
172172
if isinstance(seq_length, int) and truncation == False:
173-
seen_zero = False
173+
# From the end, once we see a non-zero, no zeros should appear (right padding)
174+
seen_nonzero = False
174175
for v in reversed(out["attention_mask"]):
175-
if v == 0:
176-
seen_zero = True
176+
if v != 0:
177+
seen_nonzero = True
177178
else:
178-
if seen_zero:
179-
pytest.fail("Non-zero attention_mask value after padded zeros.")
179+
if seen_nonzero:
180+
pytest.fail("Zero attention_mask value before non-padded tokens (padding not only in tail).")
180181

181182

0 commit comments

Comments
 (0)