Skip to content

Commit bc8bfa7

Browse files
cp: fix: ensure image rgb (#1319) into r0.3.0 (#1327)
fix: ensure image rgb (#1319) * ensure rgb * add tests --------- Signed-off-by: HuiyingLi <willwin.lee@gmail.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com> Co-authored-by: Huiying <willwin.lee@gmail.com>
1 parent 4f750de commit bc8bfa7

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed

nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import logging
3737
from typing import Any, Dict, List, Optional, Sequence, Tuple
3838

39+
from PIL import Image as PILImage
40+
3941
logger = logging.getLogger(__name__)
4042

4143
from nemo_automodel.components.datasets.vlm.utils import default_stop_tokens
@@ -602,6 +604,18 @@ def nemotron_parse_collate_fn(
602604
return batch
603605

604606

607+
def _ensure_rgb(conversations):
608+
"""Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs."""
609+
for conv in conversations:
610+
for turn in conv:
611+
content = turn.get("content")
612+
if isinstance(content, list):
613+
for item in content:
614+
if isinstance(item, dict) and isinstance(item.get("image"), PILImage.Image):
615+
item["image"] = item["image"].convert("RGB")
616+
return conversations
617+
618+
605619
def default_collate_fn(
606620
examples: Sequence[Dict[str, Any]],
607621
processor,
@@ -611,7 +625,7 @@ def default_collate_fn(
611625
if not HAVE_QWEN_VL_UTILS:
612626
raise ImportError(MISSING_QWEN_VL_UTILS_MSG)
613627

614-
conversations = [example["conversation"] for example in examples]
628+
conversations = _ensure_rgb([example["conversation"] for example in examples])
615629
processor_kwargs = {
616630
"tokenize": True,
617631
"padding": True,

tests/unit_tests/datasets/vlm/test_collate_fns.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pytest
1919
import torch
20+
from PIL import Image as PILImage
2021

2122

2223
CONVERSATION = [
@@ -1042,3 +1043,146 @@ def fake_build_labels(input_ids, conversations, processor_arg):
10421043
# Then input_ids[:, :-1] means labels also become [:, :-1] from the shape matching
10431044
# Final: [20, 30, 40]
10441045
assert batch["labels"].shape[1] == 4 # 5 - 1 = 4
1046+
1047+
1048+
# =============================================================================
1049+
# Tests for _ensure_rgb
1050+
# =============================================================================
1051+
1052+
1053+
class TestEnsureRgb:
    """Tests for _ensure_rgb helper that converts PIL images to RGB.

    ``collate_mod`` is a project fixture providing the imported
    ``collate_fns`` module under test.
    """

    def test_rgba_image_converted_to_rgb(self, collate_mod):
        # RGBA (alpha channel) input must come out as 3-channel RGB.
        img = PILImage.new("RGBA", (4, 4), (255, 0, 0, 128))
        conversations = [[
            {"role": "user", "content": [{"image": img}]},
        ]]
        collate_mod._ensure_rgb(conversations)
        assert conversations[0][0]["content"][0]["image"].mode == "RGB"

    def test_grayscale_image_converted_to_rgb(self, collate_mod):
        # Single-channel ("L") input is expanded to RGB.
        img = PILImage.new("L", (4, 4), 128)
        conversations = [[
            {"role": "user", "content": [{"image": img}]},
        ]]
        collate_mod._ensure_rgb(conversations)
        assert conversations[0][0]["content"][0]["image"].mode == "RGB"

    def test_palette_image_converted_to_rgb(self, collate_mod):
        # Palette-mode ("P") input is expanded to RGB.
        img = PILImage.new("P", (4, 4))
        conversations = [[
            {"role": "user", "content": [{"image": img}]},
        ]]
        collate_mod._ensure_rgb(conversations)
        assert conversations[0][0]["content"][0]["image"].mode == "RGB"

    def test_rgb_image_unchanged(self, collate_mod):
        # Already-RGB input stays RGB (conversion is a no-op on mode).
        img = PILImage.new("RGB", (4, 4), (255, 0, 0))
        conversations = [[
            {"role": "user", "content": [{"image": img}]},
        ]]
        collate_mod._ensure_rgb(conversations)
        result = conversations[0][0]["content"][0]["image"]
        assert result.mode == "RGB"

    def test_no_images_passthrough(self, collate_mod):
        # Text-only conversations pass through structurally unchanged.
        conversations = [[
            {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "Hi"}]},
        ]]
        result = collate_mod._ensure_rgb(conversations)
        assert result == conversations

    def test_string_content_skipped(self, collate_mod):
        # Plain-string content (not a list) must be left untouched.
        conversations = [[
            {"role": "assistant", "content": "plain string"},
        ]]
        result = collate_mod._ensure_rgb(conversations)
        assert result[0][0]["content"] == "plain string"

    def test_empty_conversations(self, collate_mod):
        # Empty input is a valid no-op.
        assert collate_mod._ensure_rgb([]) == []

    def test_multiple_images_in_one_turn(self, collate_mod):
        # Every image item in a mixed content list is converted; text
        # items are untouched.
        rgba = PILImage.new("RGBA", (4, 4))
        gray = PILImage.new("L", (4, 4))
        rgb = PILImage.new("RGB", (4, 4))
        conversations = [[
            {"role": "user", "content": [
                {"image": rgba},
                {"type": "text", "text": "describe these"},
                {"image": gray},
                {"image": rgb},
            ]},
        ]]
        collate_mod._ensure_rgb(conversations)
        items = conversations[0][0]["content"]
        assert items[0]["image"].mode == "RGB"
        assert items[1] == {"type": "text", "text": "describe these"}
        assert items[2]["image"].mode == "RGB"
        assert items[3]["image"].mode == "RGB"

    def test_multiple_conversations(self, collate_mod):
        # Conversion applies across every conversation in the batch.
        img1 = PILImage.new("RGBA", (4, 4))
        img2 = PILImage.new("L", (4, 4))
        conversations = [
            [{"role": "user", "content": [{"image": img1}]}],
            [{"role": "user", "content": [{"image": img2}]}],
        ]
        collate_mod._ensure_rgb(conversations)
        assert conversations[0][0]["content"][0]["image"].mode == "RGB"
        assert conversations[1][0]["content"][0]["image"].mode == "RGB"

    def test_non_image_dict_items_untouched(self, collate_mod):
        # Dict items without a PIL image value (text, video) are not modified.
        conversations = [[
            {"role": "user", "content": [
                {"type": "text", "text": "hi"},
                {"type": "video", "video": "clip.mp4"},
            ]},
        ]]
        result = collate_mod._ensure_rgb(conversations)
        items = result[0][0]["content"]
        assert items[0] == {"type": "text", "text": "hi"}
        assert items[1] == {"type": "video", "video": "clip.mp4"}

    def test_returns_same_list_object(self, collate_mod):
        # The helper mutates in place and returns the identical list object.
        conversations = [[{"role": "user", "content": [{"type": "text", "text": "x"}]}]]
        result = collate_mod._ensure_rgb(conversations)
        assert result is conversations
1153+
1154+
1155+
class TestDefaultCollateFnEnsureRgb:
    """Test that default_collate_fn integrates _ensure_rgb correctly.

    Uses a fake processor to observe the image mode at the moment
    ``apply_chat_template`` is called — i.e. after _ensure_rgb has run.
    ``collate_mod`` and ``fake_qwen_utils`` are project fixtures;
    ``DummyTokenizer`` is a project test helper.
    """

    def test_rgba_image_converted_before_processing(self, collate_mod, fake_qwen_utils, monkeypatch):
        # Force the qwen-vl-utils availability check to pass so
        # default_collate_fn does not raise ImportError.
        monkeypatch.setattr(collate_mod, "HAVE_QWEN_VL_UTILS", True, raising=True)

        # Records the .mode of every PIL image the processor receives.
        captured_conversations = []

        class CapturingProcessor:
            tokenizer = DummyTokenizer()

            def apply_chat_template(self, conv_list, **kwargs):
                # Walk the conversations exactly like _ensure_rgb does and
                # capture each image's mode as seen by the processor.
                for conv in conv_list:
                    for turn in conv:
                        content = turn.get("content")
                        if isinstance(content, list):
                            for item in content:
                                if isinstance(item, dict) and isinstance(item.get("image"), PILImage.Image):
                                    captured_conversations.append(item["image"].mode)
                # Minimal tensor outputs shaped like a real processor batch.
                batch_size = len(conv_list)
                input_ids = torch.arange(1, 5).unsqueeze(0).repeat(batch_size, 1)
                pixel_values = torch.ones(batch_size, 3, 64, 64, dtype=torch.float32)
                return {"input_ids": input_ids, "pixel_values": pixel_values}

        # RGBA input: should reach the processor already converted to RGB.
        rgba_img = PILImage.new("RGBA", (4, 4), (255, 0, 0, 128))
        conversation = [
            {"role": "user", "content": [{"image": rgba_img}, {"type": "text", "text": "describe"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "red"}]},
        ]

        processor = CapturingProcessor()
        collate_mod.default_collate_fn([{"conversation": conversation}], processor)

        # Exactly one image was seen, and it was RGB by the time the
        # processor ran — proving _ensure_rgb executed first.
        assert captured_conversations == ["RGB"]

0 commit comments

Comments
 (0)