|
13 | 13 | import pytest |
14 | 14 | import torch |
15 | 15 |
|
16 | | -from forge.data.collate import collate_packed |
| 16 | +from forge.data import CROSS_ENTROPY_IGNORE_IDX |
| 17 | +from forge.data.collate import collate_packed, collate_padded |
17 | 18 | from forge.data.datasets import HfIterableDataset |
18 | 19 | from forge.data.datasets.packed import ( |
19 | 20 | _SUPPORTS_FLEX_ATTENTION, |
@@ -995,3 +996,182 @@ def test_iter_restart_determinism(self, dataset_factory): |
995 | 996 | pack2["document_ids"], |
996 | 997 | msg=f"Pack {i}: document_ids mismatch between iterations", |
997 | 998 | ) |
| 999 | + |
| 1000 | + |
class TestCollatePadded:
    """Unit tests for the collate_padded function."""

    def test_empty_batch(self):
        """An empty batch collates to an empty dict."""
        assert collate_padded([]) == {}

    def test_single_sample(self):
        """A one-sample batch gains a leading batch dimension of size 1."""
        sample = {
            "tokens": torch.tensor([1, 2, 3]),
            "labels": torch.tensor([4, 5, 6]),
        }
        collated = collate_padded([sample])

        expected = {
            "tokens": torch.tensor([[1, 2, 3]]),
            "labels": torch.tensor([[4, 5, 6]]),
        }
        for field, want in expected.items():
            assert collated[field].shape == (1, 3)
            torch.testing.assert_close(collated[field], want)

    def test_equal_length_samples(self):
        """Same-length samples are stacked with no padding applied."""
        first = {
            "tokens": torch.tensor([1, 2, 3]),
            "labels": torch.tensor([4, 5, 6]),
        }
        second = {
            "tokens": torch.tensor([7, 8, 9]),
            "labels": torch.tensor([10, 11, 12]),
        }
        collated = collate_padded([first, second])

        expected = {
            "tokens": torch.tensor([[1, 2, 3], [7, 8, 9]]),
            "labels": torch.tensor([[4, 5, 6], [10, 11, 12]]),
        }
        for field, want in expected.items():
            assert collated[field].shape == (2, 3)
            torch.testing.assert_close(collated[field], want)

    def test_padding_to_longest(self):
        """Shorter sequences are right-padded up to the longest in the batch."""
        ignore = CROSS_ENTROPY_IGNORE_IDX
        batch = [
            {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])},
            {"tokens": torch.tensor([5, 6, 7, 8]), "labels": torch.tensor([9, 10, 11, 12])},
            {"tokens": torch.tensor([13, 14, 15]), "labels": torch.tensor([16, 17, 18])},
        ]
        collated = collate_padded(batch)

        # Longest sample has length 4, so every row is padded to 4.
        assert collated["tokens"].shape == (3, 4)
        assert collated["labels"].shape == (3, 4)

        # Tokens use 0 as the pad value.
        torch.testing.assert_close(
            collated["tokens"],
            torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8], [13, 14, 15, 0]]),
        )

        # Labels use CROSS_ENTROPY_IGNORE_IDX so padding is excluded from the loss.
        torch.testing.assert_close(
            collated["labels"],
            torch.tensor(
                [
                    [3, 4, ignore, ignore],
                    [9, 10, 11, 12],
                    [16, 17, 18, ignore],
                ]
            ),
        )

    def test_non_tensor_fields_preserved(self):
        """Non-tensor fields are gathered into a per-sample list."""
        batch = [
            {
                "tokens": torch.tensor([1, 2]),
                "labels": torch.tensor([3, 4]),
                "metadata": "sample1",
            },
            {
                "tokens": torch.tensor([5, 6, 7]),
                "labels": torch.tensor([8, 9, 10]),
                "metadata": "sample2",
            },
        ]
        collated = collate_padded(batch)

        assert "metadata" in collated
        assert collated["metadata"] == ["sample1", "sample2"]

    def test_metrics_flattened(self):
        """Per-sample metric lists are merged into one flat list."""

        def make_metric(key, value):
            # Minimal stand-in object exposing .key / .value attributes.
            return type("Metric", (), {"key": key, "value": value})()

        batch = [
            {
                "tokens": torch.tensor([1, 2]),
                "labels": torch.tensor([3, 4]),
                "metrics": [make_metric("loss", 1.0), make_metric("acc", 0.9)],
            },
            {
                "tokens": torch.tensor([5, 6, 7]),
                "labels": torch.tensor([8, 9, 10]),
                "metrics": [make_metric("loss", 2.0)],
            },
        ]
        collated = collate_padded(batch)

        # [[m1, m2], [m3]] flattens to [m1, m2, m3].
        assert "metrics" in collated
        assert len(collated["metrics"]) == 3

    def test_different_keys_error(self):
        """Samples with mismatched key sets raise a ValueError."""
        batch = [
            {"tokens": torch.tensor([1, 2]), "labels": torch.tensor([3, 4])},
            {"tokens": torch.tensor([5, 6]), "other_key": torch.tensor([7, 8])},
        ]

        with pytest.raises(ValueError, match="All samples must have the same keys"):
            collate_padded(batch)

    def test_generic_tensor_handling(self):
        """Every tensor field is padded independently to its own max length."""
        ignore = CROSS_ENTROPY_IGNORE_IDX
        batch = [
            {
                "tokens": torch.tensor([1, 2]),
                "labels": torch.tensor([3, 4]),
                "custom_tensor": torch.tensor([100, 200, 300]),
            },
            {
                "tokens": torch.tensor([5, 6, 7, 8]),
                "labels": torch.tensor([9, 10, 11, 12]),
                "custom_tensor": torch.tensor([400]),
            },
        ]
        collated = collate_padded(batch)

        # tokens: padded to length 4 with 0.
        assert collated["tokens"].shape == (2, 4)
        torch.testing.assert_close(
            collated["tokens"], torch.tensor([[1, 2, 0, 0], [5, 6, 7, 8]])
        )

        # labels: padded to length 4 with CROSS_ENTROPY_IGNORE_IDX.
        assert collated["labels"].shape == (2, 4)
        torch.testing.assert_close(
            collated["labels"],
            torch.tensor(
                [
                    [3, 4, ignore, ignore],
                    [9, 10, 11, 12],
                ]
            ),
        )

        # custom_tensor: its own max length is 3, padded with 0.
        assert collated["custom_tensor"].shape == (2, 3)
        torch.testing.assert_close(
            collated["custom_tensor"], torch.tensor([[100, 200, 300], [400, 0, 0]])
        )
0 commit comments