|
17 | 17 |
|
18 | 18 | import pytest |
19 | 19 | import torch |
| 20 | +from PIL import Image as PILImage |
20 | 21 |
|
21 | 22 |
|
22 | 23 | CONVERSATION = [ |
@@ -1042,3 +1043,146 @@ def fake_build_labels(input_ids, conversations, processor_arg): |
1042 | 1043 | # Then input_ids[:, :-1] means labels also become [:, :-1] from the shape matching |
1043 | 1044 | # Final: [20, 30, 40] |
1044 | 1045 | assert batch["labels"].shape[1] == 4 # 5 - 1 = 4 |
| 1046 | + |
| 1047 | + |
| 1048 | +# ============================================================================= |
| 1049 | +# Tests for _ensure_rgb |
| 1050 | +# ============================================================================= |
| 1051 | + |
| 1052 | + |
class TestEnsureRgb:
    """Tests for _ensure_rgb helper that converts PIL images to RGB."""

    @staticmethod
    def _conv_with_images(*imgs):
        """Build a one-turn user conversation with one image item per argument."""
        return [[{"role": "user", "content": [{"image": im} for im in imgs]}]]

    def test_rgba_image_converted_to_rgb(self, collate_mod):
        """An RGBA image is converted in place to RGB."""
        convs = self._conv_with_images(PILImage.new("RGBA", (4, 4), (255, 0, 0, 128)))
        collate_mod._ensure_rgb(convs)
        assert convs[0][0]["content"][0]["image"].mode == "RGB"

    def test_grayscale_image_converted_to_rgb(self, collate_mod):
        """A single-channel ("L") image is converted to RGB."""
        convs = self._conv_with_images(PILImage.new("L", (4, 4), 128))
        collate_mod._ensure_rgb(convs)
        assert convs[0][0]["content"][0]["image"].mode == "RGB"

    def test_palette_image_converted_to_rgb(self, collate_mod):
        """A palette ("P") image is converted to RGB."""
        convs = self._conv_with_images(PILImage.new("P", (4, 4)))
        collate_mod._ensure_rgb(convs)
        assert convs[0][0]["content"][0]["image"].mode == "RGB"

    def test_rgb_image_unchanged(self, collate_mod):
        """An image that is already RGB comes back still in RGB mode."""
        convs = self._conv_with_images(PILImage.new("RGB", (4, 4), (255, 0, 0)))
        collate_mod._ensure_rgb(convs)
        assert convs[0][0]["content"][0]["image"].mode == "RGB"

    def test_no_images_passthrough(self, collate_mod):
        """A text-only conversation is returned structurally unchanged."""
        convs = [[
            {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "Hi"}]},
        ]]
        assert collate_mod._ensure_rgb(convs) == convs

    def test_string_content_skipped(self, collate_mod):
        """A turn whose content is a bare string is left untouched."""
        convs = [[{"role": "assistant", "content": "plain string"}]]
        out = collate_mod._ensure_rgb(convs)
        assert out[0][0]["content"] == "plain string"

    def test_empty_conversations(self, collate_mod):
        """An empty conversation list round-trips to an empty list."""
        assert collate_mod._ensure_rgb([]) == []

    def test_multiple_images_in_one_turn(self, collate_mod):
        """Every image in a mixed content list ends up RGB; text items survive."""
        convs = [[
            {"role": "user", "content": [
                {"image": PILImage.new("RGBA", (4, 4))},
                {"type": "text", "text": "describe these"},
                {"image": PILImage.new("L", (4, 4))},
                {"image": PILImage.new("RGB", (4, 4))},
            ]},
        ]]
        collate_mod._ensure_rgb(convs)
        content = convs[0][0]["content"]
        assert content[1] == {"type": "text", "text": "describe these"}
        assert all(content[i]["image"].mode == "RGB" for i in (0, 2, 3))

    def test_multiple_conversations(self, collate_mod):
        """Images are converted across every conversation in the batch."""
        convs = (self._conv_with_images(PILImage.new("RGBA", (4, 4)))
                 + self._conv_with_images(PILImage.new("L", (4, 4))))
        collate_mod._ensure_rgb(convs)
        for conv in convs:
            assert conv[0]["content"][0]["image"].mode == "RGB"

    def test_non_image_dict_items_untouched(self, collate_mod):
        """Dict items without an image key (text, video) pass through as-is."""
        convs = [[
            {"role": "user", "content": [
                {"type": "text", "text": "hi"},
                {"type": "video", "video": "clip.mp4"},
            ]},
        ]]
        out = collate_mod._ensure_rgb(convs)
        assert out[0][0]["content"] == [
            {"type": "text", "text": "hi"},
            {"type": "video", "video": "clip.mp4"},
        ]

    def test_returns_same_list_object(self, collate_mod):
        """The helper mutates and returns the very list object it was given."""
        convs = [[{"role": "user", "content": [{"type": "text", "text": "x"}]}]]
        assert collate_mod._ensure_rgb(convs) is convs
| 1153 | + |
| 1154 | + |
class TestDefaultCollateFnEnsureRgb:
    """Test that default_collate_fn integrates _ensure_rgb correctly."""

    def test_rgba_image_converted_before_processing(self, collate_mod, fake_qwen_utils, monkeypatch):
        """By the time the processor runs, the RGBA input has been converted to RGB."""
        monkeypatch.setattr(collate_mod, "HAVE_QWEN_VL_UTILS", True, raising=True)

        seen_modes = []  # modes of every PIL image that reaches the processor

        class ModeRecordingProcessor:
            # default_collate_fn reads .tokenizer off the processor object.
            tokenizer = DummyTokenizer()

            def apply_chat_template(self, conv_list, **kwargs):
                # Record the mode of each PIL image found anywhere in the batch.
                seen_modes.extend(
                    entry["image"].mode
                    for conv in conv_list
                    for turn in conv
                    if isinstance(turn.get("content"), list)
                    for entry in turn["content"]
                    if isinstance(entry, dict)
                    and isinstance(entry.get("image"), PILImage.Image)
                )
                n = len(conv_list)
                return {
                    "input_ids": torch.arange(1, 5).unsqueeze(0).repeat(n, 1),
                    "pixel_values": torch.ones(n, 3, 64, 64, dtype=torch.float32),
                }

        conversation = [
            {"role": "user", "content": [
                {"image": PILImage.new("RGBA", (4, 4), (255, 0, 0, 128))},
                {"type": "text", "text": "describe"},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": "red"}]},
        ]

        collate_mod.default_collate_fn([{"conversation": conversation}], ModeRecordingProcessor())

        assert seen_modes == ["RGB"]
0 commit comments