diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 8dd4098a..17577982 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -28,6 +28,7 @@ class GPTSamplingParameters(SamplingParameters): # TODO: ====== Get these to memmap dataset (currently ignored) ====== use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False + use_images: bool = False @dataclasses.dataclass(kw_only=True) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 9bf29203..0fa945ab 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -5,6 +5,7 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.data.preparator.config import DatasetPreparatorConfig +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig from fast_llm.data.preprocessing.tokenizer import TokenizerConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.runnable import RunnableConfig @@ -35,6 +36,10 @@ class LanguageModelSourceConfig(Config): rejected_span: None | str = Field( default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional ) + images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional) + image_positions: None | str = Field( + default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional + ) @functools.cached_property def columns(self) -> list[str]: @@ -54,6 +59,11 @@ def has_preference_spans(self) -> bool: Assert.eq(self.chosen_span is None, self.rejected_span is None) return self.chosen_span is not None + @functools.cached_property + def has_images(self) -> bool: + Assert.eq(self.images is None, self.image_positions is None) + return self.images is not None + def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: @@ -177,6 +187,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patches: ImagePatchConfig = Field( + desc="Configuration for the image patches, if enabled.", + hint=FieldHint.feature, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." 
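A minimal sketch (not part of the patch) of the dataset row shape that the new `images` / `image_positions` schema fields describe; the column names and values below are illustrative assumptions, not taken from a real dataset.

# One source document with two images: "image_positions" are character offsets into "text",
# and "images" holds the raw encoded bytes (PNG/JPEG) for each image, in the same order.
sample = {
    "text": "A cat sitting on a mat.",
    "images": [open("cat.png", "rb").read(), open("mat.png", "rb").read()],  # hypothetical files
    "image_positions": [6, 22],
}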
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 18d4d46e..0d5c0178 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -30,6 +30,7 @@ from fast_llm.data.preprocessing.tokenizer import Tokenizer from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter +from fast_llm.data.sample.patch import PatchSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type @@ -43,6 +44,7 @@ class SpanType(enum.StrEnum): loss_masking = "loss_masking" chosen = "chosen" rejected = "rejected" + image = "image" class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]): @@ -231,6 +233,30 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: text = full_chosen_text + full_rejected_text all_spans.extend(chosen_spans + rejected_span) + if self._source_schema.has_images: + # Get the images and positions, sorted by position. + images, image_positions = ( + zip( + *sorted( + zip( + sample[self._source_schema.images], + sample[self._source_schema.image_positions], + strict=True, + ), + key=lambda x: x[1], + ) + ) + if len(sample[self._source_schema.images]) > 0 + else ([], []) + ) + # Get the image patches and associated data. + image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = ( + self._config.image_patches.get_patches(images, self._data_type) + ) + patch_count_cumsum = padded_cumsum(patch_counts).tolist() + # Add an empty "span" at each image position so we know where to insert them in the tokenized sequence. + all_spans.extend([(SpanType.image, (position, position)) for position in image_positions]) + # Sort the spans by location (begin), keeping track of their type. # Note: overlapping spans are not supported (explicit assertion in the tokenizer). span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], []) @@ -241,8 +267,26 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: # Gather token spans by type. token_spans_by_type = collections.defaultdict(list) - for span_type, token_span in zip(span_types, token_spans, strict=True): - token_spans_by_type[span_type].append(token_span) + if self._source_schema.has_images: + # Insert the image token ids in the token sequence and shift the spans accordingly. + tokens_shift = 0 + image_index = 0 + for span_type, (begin, end) in zip(span_types, token_spans, strict=True): + # Account for the tokens already inserted. + begin = begin + tokens_shift + end = end + tokens_shift + if span_type == SpanType.image: + # Shift the token map to the image location. + image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin + # Insert the placeholder and image break tokens. 
+ tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]]) + tokens_shift += len(image_token_ids[image_index]) + image_index += 1 + else: + token_spans_by_type[span_type].append((begin, end)) + else: + for span_type, token_span in zip(span_types, token_spans, strict=True): + token_spans_by_type[span_type].append(token_span) sample_size = len(tokens) @@ -264,6 +308,11 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: if self._source_schema.has_preference_spans else None ), + ( + PatchSample(image_patches, image_token_maps, image_position_ids, sample_size, patch_counts) + if self._source_schema.has_images + else None + ), ) def generate_config_yaml_for_sharded_dst( diff --git a/fast_llm/data/preprocessing/image_patch.py b/fast_llm/data/preprocessing/image_patch.py new file mode 100644 index 00000000..61e5dd7b --- /dev/null +++ b/fast_llm/data/preprocessing/image_patch.py @@ -0,0 +1,185 @@ +import functools +import io +import math +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert, div + +if typing.TYPE_CHECKING: + import torch + + +@config_class() +class ImagePatchConfig(Config): + """ + Configuration for image patches. + Defines the patch size, maximum image size and optional break/end tokens used to split images into fixed-size patches during dataset preparation. + """ + + height: int = Field( + default=16, + desc="Height of the image patches, in pixels.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + width: int = Field( + default=16, + desc="Width of the image patches, in pixels.", + hint=FieldHint.core, + valid=check_field(Assert.gt, 0), + ) + max_image_height: int = Field( + default=1024, + desc="Maximum height of the complete image, in pixels. " + "If the original image is larger than this, it will be resized to this height.", + hint=FieldHint.optional, + ) + max_image_width: int = Field( + default=1024, + desc="Maximum width of the complete image, in pixels. " + "If the original image is larger than this, it will be resized to this width.", + hint=FieldHint.optional, + ) + image_break_token: int | None = Field( + default=None, + desc="Add this token at the end of each row of image patches.", + hint=FieldHint.optional, + ) + image_end_token: int | None = Field( + default=None, + desc="Add this token after the last patch of each image. "
+ "If `image_break_token` is also defined, only `image_end_token` is added after the last row.", + hint=FieldHint.optional, + ) + + @property + def num_channels(self) -> int: + # assume 3 channels (RGB) for all images + return 3 + + @functools.cached_property + def max_patches_height(self) -> int: + return div(self.max_image_height, self.height) + + @functools.cached_property + def max_patches_width(self) -> int: + return div(self.max_image_width, self.width) + + def _validate(self): + super()._validate() + Assert.gt(self.max_patches_height, 0) + Assert.gt(self.max_patches_width, 0) + + def get_patches( + self, images: list[bytes], token_data_type: DataType = DataType.int64 + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["torch.Tensor"], list[int]]: + import torch + + if len(images) > 0: + image_patches, image_positions, image_token_maps, image_token_ids = zip( + *(self._get_patches(image, token_data_type) for image in images) + ) + return ( + torch.cat(image_patches), + torch.cat(image_positions), + torch.cat(image_token_maps), + image_token_ids, + [len(position_ids) for position_ids in image_positions], + ) + else: + # Return empty tensors of appropriate shapes and data types so we can concatenate with other documents. + return ( + torch.empty(0, self.num_channels, self.height, self.width, dtype=torch.uint8), + torch.empty(0, 2, dtype=torch.int64), + torch.empty(0, dtype=torch.int64), + [], + [0], + ) + + def _get_patches( + self, image_bytes: bytes, token_data_type: DataType = DataType.int64 + ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: + import numpy as np + import PIL.Image + import torch + + with PIL.Image.open(io.BytesIO(image_bytes)) as image: + if image.mode != "RGB": + # Convert all images to RGB + image = image.convert("RGB") + image = torch.tensor(np.array(image)).permute(2, 0, 1) # HWC to CHW + Assert.eq(image.dtype, torch.uint8) + + # Resize to a multiple of patch size smaller or equal to max size. + image = self._resize(image) + num_patches_height = div(image.size(1), self.height) + num_patches_width = div(image.size(2), self.width) + # Convert to patches. (`torch.nn.functional.unfold` not supported for uint8.) + patches = ( + image.view(self.num_channels, num_patches_height, self.height, num_patches_width, self.width) + .permute(3, 1, 0, 2, 4) + .flatten(0, 1) + ) + + positions = torch.stack( + [ + torch.arange(num_patches_height).repeat_interleave(num_patches_width), + torch.arange(num_patches_width).repeat(num_patches_height), + ], + 1, + ) + + token_map = torch.arange(0, num_patches_width * num_patches_height, dtype=torch.int64) + if self.image_break_token is None: + token_ids = [-100] * (num_patches_width * num_patches_height) + if self.image_end_token is not None: + token_ids.append(self.image_end_token) + else: + token_ids = ([-100] * num_patches_width + [self.image_break_token]) * num_patches_height + token_map += torch.arange(num_patches_height).repeat_interleave(num_patches_width) + if self.image_end_token is not None: + token_ids[-1] = self.image_end_token + + return patches, positions, token_map, torch.tensor(token_ids, dtype=token_data_type.torch) + + def _resize(self, image: "torch.Tensor") -> "torch.Tensor": + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. 
+ """ + import torchvision.transforms.v2 as torchvision_transforms + + target_height, target_width = image.shape[1:] + ratio = max(target_height / self.max_image_height, target_width / self.max_image_width, 1) + target_height = self.height * math.ceil(target_height / self.height / ratio) + target_width = self.width * math.ceil(target_width / self.width / ratio) + + # Cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + while max(image.size(1) / target_height, image.size(2) / target_width) > 2: + image = torchvision_transforms.functional.resize( + image, + size=(math.ceil(image.size(1) / 2), math.ceil(image.size(2) / 2)), + interpolation=torchvision_transforms.InterpolationMode.BICUBIC, + ) + + # TODO: options for interpolation mode? + return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + +@config_class() +class ImageNormalizationConfig(Config): + scale: float = Field(default=255.0) + # Default values from OpenAI Clip. + mean: tuple[float, float, float] = Field(default=(0.48145466, 0.4578275, 0.40821073)) + std: tuple[float, float, float] = Field(default=(0.26862954, 0.26130258, 0.27577711)) + + def normalize(self, image: "torch.Tensor") -> "torch.Tensor": + import torchvision.transforms.v2 as torchvision_transforms + + return torchvision_transforms.functional.normalize(image / self.scale, list(self.mean), list(self.std)) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 6f485bf8..88ca05b9 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -6,6 +6,7 @@ import torch from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.image_patch import ImageNormalizationConfig from fast_llm.data.sample.abstract import ( Batch, MemmapIndexDatasetReaderConfig, @@ -15,6 +16,7 @@ NullReaderConfig, Sample, ) +from fast_llm.data.sample.patch import PatchBatch, PatchSample, PatchWriter from fast_llm.data.sample.range import RangeBatch, RangeSample, RangeWriter from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter from fast_llm.utils import Assert @@ -27,11 +29,13 @@ def __init__( loss_masking_spans: RangeSample | None = None, chosen_spans: RangeSample | None = None, rejected_spans: RangeSample | None = None, + image_patches: PatchSample | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans + self.image_patches = image_patches @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: @@ -40,6 +44,7 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), + _merge_optional(PatchSample.from_documents, [document.image_patches for document in documents]), ) def crop(self, begin: int, end: int) -> typing.Self: @@ -48,6 +53,7 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.loss_masking_spans, begin, end), _crop_optional(self.chosen_spans, begin, end), 
_crop_optional(self.rejected_spans, begin, end), + _crop_optional(self.image_patches, begin, end), ) def __len__(self) -> int: @@ -59,6 +65,7 @@ def get_padding(self, size: int) -> typing.Self: None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), None if self.chosen_spans is None else self.chosen_spans.get_padding(size), None if self.rejected_spans is None else self.rejected_spans.get_padding(size), + None if self.image_patches is None else self.image_patches.get_padding(size), ) @@ -69,11 +76,13 @@ def __init__( loss_masking_spans: RangeBatch | None = None, chosen_spans: RangeBatch | None = None, rejected_spans: RangeBatch | None = None, + image_patches: PatchBatch | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans + self.image_patches = image_patches @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: @@ -82,26 +91,16 @@ def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.S _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), + _merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]), ) - def to_samples(self) -> list[LanguageModelSample]: - return [ - LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) - for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( - self.tokens.to_samples(), - None if self.loss_masking_spans is None else self.loss_masking_spans.to_samples(), - None if self.chosen_spans is None else self.chosen_spans.to_samples(), - None if self.rejected_spans is None else self.rejected_spans.to_samples(), - strict=True, - ) - ] - def crop(self, begin: int, end: int) -> typing.Self: return self.__class__( self.tokens.crop(begin, end), _crop_optional(self.loss_masking_spans, begin, end), _crop_optional(self.chosen_spans, begin, end), _crop_optional(self.rejected_spans, begin, end), + _crop_optional(self.image_patches, begin, end), ) def to_device_(self, device: "torch.device | str"): @@ -112,6 +111,8 @@ def to_device_(self, device: "torch.device | str"): self.chosen_spans.to_device_(device) if self.rejected_spans is not None: self.rejected_spans.to_device_(device) + if self.image_patches is not None: + self.image_patches.to_device_(device) def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: @@ -132,6 +133,7 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): loss_masking_spans: MemmapReaderBaseConfig = Field() chosen_spans: MemmapReaderBaseConfig = Field() rejected_spans: MemmapReaderBaseConfig = Field() + image_patches: MemmapReaderBaseConfig = Field() def __len__(self) -> int: return len(self.tokens) @@ -155,6 +157,7 @@ def _expected_buffer_size(self) -> int: + self.loss_masking_spans.expected_buffer_size + self.chosen_spans.expected_buffer_size + self.rejected_spans.expected_buffer_size + + self.image_patches.expected_buffer_size ) @@ -166,17 +169,28 @@ def __init__(self, config: ConfigType, buffer: memoryview): self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) self._chosen_spans = self._config.chosen_spans.get_reader(buffer) self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + 
self._image_patches = self._config.image_patches.get_reader(buffer) + + if self._image_patches is not None: + # TODO: Make this configurable. + self._image_normalization_config = ImageNormalizationConfig() @property def num_tokens(self) -> int: return self._config.tokens.num_tokens def get_document(self, index: int, begin: int, end: int) -> Sample: + if self._image_patches is None: + image_patches = None + else: + image_patches = self._image_patches.get_document(index, begin, end) + image_patches.patches = self._image_normalization_config.normalize(image_patches.patches) return LanguageModelSample( self._tokens.get_document(index, begin, end), None if self._loss_masking_spans is None else self._loss_masking_spans.get_document(index, begin, end), None if self._chosen_spans is None else self._chosen_spans.get_document(index, begin, end), None if self._rejected_spans is None else self._rejected_spans.get_document(index, begin, end), + image_patches, ) def get_document_sizes(self) -> torch.Tensor: @@ -189,6 +203,7 @@ def get_document_size(self, index: int) -> int: class LanguageModelWriter(MemmapWriter): _has_loss_masking_spans: bool | None = None _has_preference_spans: bool | None = None + _has_image_patches: bool | None = None def __enter__(self): super().__enter__() @@ -202,6 +217,7 @@ def __enter__(self): self._loss_masking_span_writer = RangeWriter(self._path.joinpath("loss_masking_spans")).__enter__() self._chosen_spans_writer = RangeWriter(self._path.joinpath("chosen_spans")).__enter__() self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() + self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() return self def write(self, document: LanguageModelSample): @@ -231,11 +247,22 @@ def write(self, document: LanguageModelSample): self._chosen_spans_writer.write(document.chosen_spans) self._rejected_spans_writer.write(document.rejected_spans) + # Ensure either all samples have image patches or none of them do. + if self._has_image_patches is None: + self._has_image_patches = document.image_patches is not None + else: + Assert.eq(self._has_image_patches, document.image_patches is not None) + + # Write image patches + if self._has_image_patches: + self._image_patches_writer.write(document.image_patches) + def __exit__(self, exc_type, exc_val, exc_tb): self._token_writer.__exit__(exc_type, exc_val, exc_tb) self._loss_masking_span_writer.__exit__(exc_type, exc_val, exc_tb) self._chosen_spans_writer.__exit__(exc_type, exc_val, exc_tb) self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) + self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) if exc_type is None: # A dummy config so we can verify the begin and end offsets. 
@@ -263,6 +290,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): config.rejected_spans.end, ) + if self._has_image_patches: + _copy_chunked( + self._path.joinpath("image_patches"), + self._stream, + config.image_patches.begin, + config.image_patches.end, + ) + self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) @@ -286,6 +321,11 @@ def _get_config(self, begin: int, end: int | None): else: chosen_spans = NullReaderConfig() rejected_spans = NullReaderConfig() + if self._has_image_patches: + image_patches = self._image_patches_writer.get_config(offset) + offset = image_patches.end + else: + image_patches = NullReaderConfig() if end is None: end = offset + len(LanguageModelReaderConfig.footer) @@ -297,6 +337,7 @@ loss_masking_spans=loss_masking_spans, chosen_spans=chosen_spans, rejected_spans=rejected_spans, + image_patches=image_patches, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py new file mode 100644 index 00000000..dd7c9850 --- /dev/null +++ b/fast_llm/data/sample/patch.py @@ -0,0 +1,303 @@ +import math +import typing + +import numpy as np +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.sample.abstract import ( + Batch, + MemmapReader, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert, get_unique, padded_cumsum + + +def filter_lengths(lengths: list[int], filter: torch.Tensor) -> list[int]: + length_cumsum = padded_cumsum(lengths) + filtered_lengths = (filter[begin:end].sum().item() for begin, end in zip(length_cumsum[:-1], length_cumsum[1:])) + return [length for length in filtered_lengths if length > 0] + + +class PatchSample(Sample): + """ + A reusable component holding a set of fixed-shape patches (ex. images, audio, video), + each of which provides a single token embedding in a multimodal model. + """ + + def __init__( + self, + patches: torch.Tensor, + token_map: torch.Tensor, + positions: torch.Tensor, + sample_size: int, + lengths: list[int] | None = None, + ): + # Tensor of dimensions (patch, *patch_shape) + self.patches = patches + # Mapping from patch to token index + self.token_map = token_map + # A position identifier for each patch in the patch grid. + Assert.eq(positions.shape, (self.patches.size(0), self.patches.ndim - 2)) + self.positions = positions + # Number of tokens in the sample (not the number of patches) + self.sample_size = sample_size + # Length of each patch group (ex. image) in the sample. TODO: Use cumsums instead?
+ if lengths is None: + lengths = [len(patches)] + else: + Assert.eq(sum(lengths), len(patches)) + self.lengths = lengths + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + total_size = 0 + embedding_maps = [] + for document in documents: + embedding_maps.append(document.token_map + total_size) + total_size += document.sample_size + return cls( + torch.cat([document.patches for document in documents]), + torch.cat(embedding_maps), + torch.cat([document.positions for document in documents]), + total_size, + sum((document.lengths for document in documents), []), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + patch_filter = (self.token_map >= begin) & (self.token_map < end) + return self.__class__( + self.patches[patch_filter], + self.token_map[patch_filter] - begin, + self.positions[patch_filter], + sample_size, + filter_lengths(self.lengths, patch_filter), + ) + + def __len__(self) -> int: + return self.sample_size + + def get_padding(self, size: int) -> typing.Self: + return PatchSample( + self.patches.new_empty((0, *self.patches.shape[1:])), + self.token_map.new_empty(0), + self.positions.new_empty(0), + size, + [], + ) + + +class PatchBatch(Batch): + def __init__( + self, + patches: torch.Tensor, + sample_map: torch.Tensor, + token_map: torch.Tensor, + positions: torch.Tensor, + num_samples: int, + sample_size: int, + lengths: list[int], + ): + # Concatenated along patch index rather than stacked since the lengths are not constant + self.patches = patches + # Mapping from patch to sample index + self.sample_map = sample_map + self.token_map = token_map + self.positions = positions + self.num_samples = num_samples + self.sample_size = sample_size + self.lengths = lengths + + @classmethod + def from_samples(cls, samples: typing.Sequence[PatchSample]) -> typing.Self: + return cls( + torch.cat([sample.patches for sample in samples]), + torch.cat( + [torch.full_like(sample.token_map, sample_index) for sample_index, sample in enumerate(samples)] + ), + torch.cat([sample.token_map for sample in samples]), + torch.cat([sample.positions for sample in samples]), + len(samples), + get_unique(sample.sample_size for sample in samples), + [length for sample in samples for length in sample.lengths], + ) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + patch_filter = (self.token_map >= begin) & (self.token_map < end) + + return self.__class__( + self.patches[patch_filter], + self.sample_map[patch_filter], + self.token_map[patch_filter], + self.positions[patch_filter], + self.num_samples, + sample_size, + filter_lengths(self.lengths, patch_filter), + ) + + def to_device_(self, device: "torch.device | str"): + self.patches = self.patches.to(device, non_blocking=True) + self.sample_map = self.sample_map.to(device, non_blocking=True) + self.token_map = self.token_map.to(device, non_blocking=True) + self.positions = self.positions.to(device, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "patch"}) +class PatchReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"patch begin" + footer: typing.ClassVar[bytes] = b"patch end" + num_documents: int = Field() + num_patches: int = Field() + num_patch_groups: int = Field() + patch_shape: tuple[int, ...] 
= Field() + data_type: DataType = Field() + + def __len__(self) -> int: + return self.num_documents + + @property + def reader_class(self) -> "type[PatchReader]": + return PatchReader + + @property + def writer_class(self) -> "type[PatchWriter]": + return PatchWriter + + @property + def patch_size(self) -> int: + return math.prod(self.patch_shape) + + @property + def grid_dims(self) -> int: + return len(self.patch_shape) - 1 + + @property + def _expected_buffer_size(self) -> int: + return ( + self.num_patches * self.patch_size * self.data_type.torch.itemsize + + ((1 + self.grid_dims) * self.num_patches + self.num_patch_groups + 2 * self.num_documents + 2) + * torch.int32.itemsize + ) + + +class PatchReader[ConfigType: PatchReaderConfig](MemmapReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._patches = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_patches * self._config.patch_size, + ).view(self._config.num_patches, *self._config.patch_shape) + offset = self._patches.nbytes + self._token_map = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_patches, + offset=offset, + ) + offset += self._token_map.nbytes + self._positions = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_patches * self._config.grid_dims, + offset=offset, + ).view(self._config.num_patches, self._config.grid_dims) + offset += self._positions.nbytes + self._patch_count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_documents + 1, + offset=offset, + ) + offset += self._patch_count_cumsums.nbytes + self._group_lengths = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_patch_groups, + offset=offset, + ) + offset += self._group_lengths.nbytes + self._group_count_cumsums = torch.frombuffer( + self._buffer, + dtype=torch.int32, + count=self._config.num_documents + 1, + offset=offset, + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + token_map = self._token_map[ + token_slice := slice(self._patch_count_cumsums[index], self._patch_count_cumsums[index + 1]) + ] + patch_filter = (token_map >= begin) & (token_map < end) + return PatchSample( + self._patches[token_slice][patch_filter], + token_map[patch_filter] - begin, + self._positions[token_slice][patch_filter], + end - begin, + filter_lengths( + self._group_lengths[self._group_count_cumsums[index] : self._group_count_cumsums[index + 1]].tolist(), + patch_filter, + ), + ) + + +class PatchWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._patch_count_cumsum = [0] + self._group_count_cumsum = [0] + self._token_map = [] + self._positions = [] + self._group_lengths = [] + self._data_type = None + self._patch_shape = None + return self + + def write(self, document: PatchSample): + super().write(document) + if self._data_type is None: + self._data_type = document.patches.dtype + else: + Assert.eq(self._data_type, document.patches.dtype) + if self._patch_shape is None: + self._patch_shape = tuple(document.patches.shape[1:]) + else: + Assert.eq(self._patch_shape, document.patches.shape[1:]) + self._stream.write(document.patches.numpy().tobytes()) + self._token_map.extend(document.token_map) + self._positions.extend(document.positions) + self._patch_count_cumsum.append(self._patch_count_cumsum[-1] + len(document.patches)) + self._group_count_cumsum.append(self._group_count_cumsum[-1] 
+ len(document.lengths)) + self._group_lengths.extend(document.lengths) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + Assert.lt(self._patch_count_cumsum[-1], np.iinfo(np.int32).max) + self._stream.write(np.array(self._token_map, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._positions, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._patch_count_cumsum, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._group_lengths, dtype=np.int32).tobytes(order="C")) + self._stream.write(np.array(self._group_count_cumsum, dtype=np.int32).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[PatchReaderConfig]: + return PatchReaderConfig + + def _get_config(self, begin: int, end: int): + return PatchReaderConfig( + begin=begin, + end=end, + num_documents=len(self._patch_count_cumsum) - 1, + num_patches=self._patch_count_cumsum[-1], + num_patch_groups=self._group_count_cumsum[-1], + patch_shape=self._patch_shape, + data_type=DataType.from_torch(self._data_type), + ) diff --git a/tests/data/test_image_patch.py b/tests/data/test_image_patch.py new file mode 100644 index 00000000..197d1db2 --- /dev/null +++ b/tests/data/test_image_patch.py @@ -0,0 +1,168 @@ +import hashlib +import io + +import datasets +import numpy as np +import PIL.Image +import pytest + +from fast_llm.data.dataset.gpt.config import GPTDatasetFromFileConfig, GPTSamplingParameters +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.utils import Assert +from tests.data.common import get_dataset_config +from tests.data.test_preparator import COMMON_DATASET_LENGTH, COMMON_DATASET_TEXT +from tests.utils.dataset import get_test_dataset_with_image_patches + +DATASET_WITH_IMAGE_PATCHES_TOKENS = [55750, 56809, 59145, 59145] +DATASET_WITH_IMAGE_PATCHES_IMAGE_MD5 = { + 27: [], + 30: ["a2c34e404506fe664efcdb520642f260"], + 31: ["3aca101f63e09f75b070fcc53ee44895", "0e027a3d45b34767a0cf5c67280dc825"], + 77: [], + 87: ["65597f534c1e2a257ac1b5282100d541", "612404822bc61a0b2a293890e7246621"], +} +DATASET_WITH_IMAGE_PATCHES_IMAGE_POSITIONS = { + 27: [], + 30: [3], + 31: [2, 4], + 77: [], + 87: [1, 2], +} +DATASET_WITH_IMAGE_PATCHES_IMAGE_SHAPES = { + 27: [], + 30: [(30, 4)], + 31: [(7, 22), (14, 24)], + 77: [], + 87: [(17, 4), (15, 12)], +} +DATASET_WITH_IMAGE_PATCHES_SAMPLES = { + 27: [49152, 63, 82, 11, 27799, 49152], + 30: [49152, 31, 2327, (4, 1), 27, 1448, 62, 43, 49152], + 31: [49152, 60, 55, (2, 4), 80, 30, (3, 4), 85, 22, 18, 49152], + 77: [49152, 13736, 85, 52, 22, 46, 5, 11807, 49152], + 87: [49152, 52, (4, 1), 89, (4, 3), 75, 11, 71, 49152], +} + + +def _shifted_range(begin: int, height_patches: int, width_patches: int, shift: int = 1): + return [ + i + for row in range(height_patches) + for i in range(begin + row * (width_patches + shift), begin + row * (width_patches + shift) + width_patches) + ] + + +DATASET_WITH_IMAGE_PATCHES_TOKEN_MAP = { + 27: [[] for _ in range(4)], + 30: [ + list(range(3, 7)), + list(range(3, 7)), + _shifted_range(3, 4, 1), + _shifted_range(3, 4, 1), + ], + 31: [ + [*range(3, 11), *range(13, 25)], + [*range(3, 11), *range(14, 26)], + _shifted_range(3, 2, 4) + _shifted_range(15, 3, 4), + _shifted_range(3, 2, 4) + _shifted_range(15, 3, 4), + ], + 77: [[] for _ in range(4)], + 87: [ + [*range(2, 6), *range(7, 19)], + [*range(2, 6), *range(8, 20)], + _shifted_range(2, 4, 1) + 
_shifted_range(11, 4, 3), + _shifted_range(2, 4, 1) + _shifted_range(11, 4, 3), + ], +} + + +def _position_ids(height_patches: int, width_patches: int): + return [[i, j] for i in range(height_patches) for j in range(width_patches)] + + +DATASET_WITH_IMAGE_PATCHES_POSITIONS = { + 27: [], + 30: _position_ids(4, 1), + 31: _position_ids(2, 4) + _position_ids(3, 4), + 77: [], + 87: _position_ids(4, 1) + _position_ids(4, 3), +} +DATASET_WITH_IMAGE_PATCHES_LENGTHS = { + 27: [], + 30: [4], + 31: [8, 12], + 77: [], + 87: [4, 12], +} +DATASET_WITH_IMAGE_PATCHES_PATCHES_MD5 = { + 27: "d41d8cd98f00b204e9800998ecf8427e", + 30: "f9e5a216990b1a3646677195532dddec", + 31: "c56ce50e02154d52e82d320547e3973f", + 77: "d41d8cd98f00b204e9800998ecf8427e", + 87: "90ab851ceb87678b4c151edee2049702", +} + + +def _get_image_tokens( + height_patches: int, width_patches: int, image_break_token: int | None, image_end_token: int | None +): + return ([-100] * width_patches + ([] if image_break_token is None else [image_break_token])) * ( + height_patches - 1 + ) + ( + [-100] * width_patches + + ( + [image_end_token] + if image_end_token is not None + else [] if image_break_token is None else [image_break_token] + ) + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("image_break_token", (None, 55)) +@pytest.mark.parametrize("image_end_token", (None, 132)) +def test_gpt_data_with_image_patches(image_break_token, image_end_token): + _, config, hf_path = get_test_dataset_with_image_patches(image_break_token, image_end_token) + dataset: MemmapDataset[LanguageModelSample] = get_dataset_config(config, GPTDatasetFromFileConfig).build() + test_index = 2 * (image_break_token is not None) + (image_end_token is not None) + + hf_dataset = datasets.load_from_disk(hf_path)["train"] + + # Check global stats. + Assert.eq(len(dataset), len(hf_dataset), COMMON_DATASET_LENGTH) + Assert.eq(dataset.num_tokens, DATASET_WITH_IMAGE_PATCHES_TOKENS[test_index]) + + # Check some numerical values. 
+ for index in DATASET_WITH_IMAGE_PATCHES_SAMPLES: + Assert.eq(hf_dataset[index]["text"], COMMON_DATASET_TEXT[index]) + Assert.eq( + [hashlib.md5(image).hexdigest() for image in hf_dataset[index]["images"]], + DATASET_WITH_IMAGE_PATCHES_IMAGE_MD5[index], + ) + Assert.eq( + [np.array(PIL.Image.open(io.BytesIO(image))).shape[:2] for image in hf_dataset[index]["images"]], + DATASET_WITH_IMAGE_PATCHES_IMAGE_SHAPES[index], + ) + Assert.eq(hf_dataset[index]["image_positions"], DATASET_WITH_IMAGE_PATCHES_IMAGE_POSITIONS[index]) + + document = dataset.get_document( + index, parameters=GPTSamplingParameters(num_samples=0, sequence_length=0, use_images=True) + ) + expected_tokens = [ + tokens + for token_or_patches in DATASET_WITH_IMAGE_PATCHES_SAMPLES[index] + for tokens in ( + _get_image_tokens(*token_or_patches, image_break_token, image_end_token) + if isinstance(token_or_patches, tuple) + else [token_or_patches] + ) + ] + Assert.eq(document.tokens.tokens.tolist(), expected_tokens) + Assert.eq(document.image_patches.token_map.tolist(), DATASET_WITH_IMAGE_PATCHES_TOKEN_MAP[index][test_index]) + Assert.eq(document.image_patches.positions.tolist(), DATASET_WITH_IMAGE_PATCHES_POSITIONS[index]) + Assert.eq(document.image_patches.lengths, DATASET_WITH_IMAGE_PATCHES_LENGTHS[index]) + Assert.eq( + hashlib.md5(document.image_patches.patches.numpy().tobytes()).hexdigest(), + DATASET_WITH_IMAGE_PATCHES_PATCHES_MD5[index], + ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 28d28bd9..b88f834c 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -1,10 +1,13 @@ +import io import pathlib import typing import datasets import numpy as np +import PIL.Image from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig +from fast_llm.data.preprocessing.image_patch import ImagePatchConfig from fast_llm.utils import padded_cumsum from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH @@ -71,6 +74,51 @@ def get_random_preference_spans(texts, random_state: np.random.RandomState = np. return {"text": texts_, "chosen_span": chosen_spans, "rejected_span": rejected_spans} +def _save_image_to_bytes(image: np.ndarray, format="PNG", mode="RGB"): + buffer = io.BytesIO() + PIL.Image.fromarray(image, mode).save(buffer, format=format) + return buffer.getvalue() + + +def get_random_images( + document_sizes: np.ndarray, + min_images: int, + max_images: int, + min_image_size: int, + max_image_size: int, + random_state: np.random.RandomState = np.random, +): + # Randomize image count for each sample. + image_counts = random_state.randint(min_images, max_images + 1, num_documents := len(document_sizes)) + image_count_cumsums = padded_cumsum(image_counts) + num_images = image_count_cumsums[-1] + # Randomize image shapes. + image_shapes = random_state.randint(min_image_size, max_image_size + 1, [num_images, 2]) + pixel_count_cumsum = padded_cumsum(image_shapes.prod(1) * 3) + # Generate random pixels. + pixels = random_state.randint(0, 256, pixel_count_cumsum[-1], dtype=np.uint8) + # Convert pixels to image byte buffers. + images = [ + _save_image_to_bytes( + pixels[pixel_count_cumsum[image_index] : pixel_count_cumsum[image_index + 1]].reshape( + [*image_shapes[image_index], 3] + ) + ) + for image_index in range(num_images) + ] + # Gather images by documents. 
+ images = [ + images[image_count_cumsums[document_index] : image_count_cumsums[document_index + 1]] + for document_index in range(num_documents) + ] + # Generate random image positions. + image_positions = [ + np.sort(random_state.choice(range(document_size), image_counts[document_index], replace=False)).tolist() + for document_index, document_size in enumerate(document_sizes) + ] + return images, image_positions + + def _get_hf_test_dataset( seed: int = 1234, num_documents: int = 1000, @@ -79,6 +127,10 @@ def _get_hf_test_dataset( min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, + min_images: int = 0, + max_images: int = 0, + min_image_size: int = 4, + max_image_size: int = 32, ): random_state = np.random.RandomState(seed) # Generate random document sizes (character count). @@ -94,6 +146,11 @@ def _get_hf_test_dataset( document_sizes, min_loss_masking_spans, max_loss_masking_spans, random_state, use_last_format=True ) + if max_images > 0: + dataset_dict["images"], dataset_dict["image_positions"] = get_random_images( + document_sizes, min_images, max_images, min_image_size, max_image_size, random_state + ) + return datasets.Dataset.from_dict(dataset_dict) @@ -110,6 +167,11 @@ def _get_test_dataset( max_loss_masking_spans: int = 0, has_preference_spans: bool = False, splits: dict[str, float] | None = None, + min_images: int = 0, + max_images: int = 0, + image_patch_config: ImagePatchConfig | None = None, + min_image_size: int = 4, + max_image_size: int = 32, ): config_paths = ( [path / "fast_llm_config.yaml"] @@ -127,6 +189,10 @@ def _get_test_dataset( min_loss_masking_spans=min_loss_masking_spans, max_loss_masking_spans=max_loss_masking_spans, has_preference_spans=has_preference_spans, + min_images=min_images, + max_images=max_images, + min_image_size=min_image_size, + max_image_size=max_image_size, ) datasets.DatasetDict({"train": dataset}).save_to_disk(hf_path) source_schema = {"text": "text"} @@ -135,6 +201,9 @@ def _get_test_dataset( if has_preference_spans: source_schema["chosen_span"] = "chosen_span" source_schema["rejected_span"] = "rejected_span" + if max_images > 0: + source_schema["images"] = "images" + source_schema["image_positions"] = "image_positions" download_santacoder_tokenizer() preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( @@ -148,6 +217,7 @@ def _get_test_dataset( "output_path": path, "documents_per_shard": documents_per_shard, "splits": splits, + "image_patches": {} if image_patch_config is None else image_patch_config, } ) preparator_config.run() @@ -198,5 +268,21 @@ def get_test_dataset_with_preference_spans(): return _get_test_dataset(DATASET_CACHE / "dataset_with_preference_spans", seed=1234, has_preference_spans=True) +def get_test_dataset_with_image_patches(image_break_token: int | None = None, image_end_token: int | None = None): + return _get_test_dataset( + DATASET_CACHE / f"dataset_with_image_patches_{image_break_token}_{image_end_token}", + seed=1234, + max_images=2, + image_patch_config=ImagePatchConfig( + height=4, + width=4, + max_image_height=16, + max_image_width=16, + image_break_token=image_break_token, + image_end_token=image_end_token, + ), + ) + + def get_model_test_dataset(): return _get_test_dataset(DATASET_CACHE / "model_dataset", seed=1234, vocab_size=MODEL_TEST_VOCAB_SIZE)
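For reference, a rough usage sketch (not part of the patch) of the new `ImagePatchConfig.get_patches` entry point; the patch size, maximum image size, break token and input file are assumptions chosen for illustration.

import pathlib

from fast_llm.data.preprocessing.image_patch import ImagePatchConfig
from fast_llm.engine.config_utils.data_type import DataType

# Hypothetical input file; any encoded (PNG/JPEG) byte string works.
image_bytes = pathlib.Path("example.png").read_bytes()

# 16x16 patches, images capped at 256x256 (illustrative values).
config = ImagePatchConfig(height=16, width=16, max_image_height=256, max_image_width=256, image_break_token=55)
patches, positions, token_map, token_ids, lengths = config.get_patches([image_bytes], DataType.int64)
# patches: (num_patches, 3, 16, 16) uint8 tensor, one row per patch.
# positions: (num_patches, 2) grid coordinates of each patch within its image.
# token_map: token index each patch maps to, accounting for the inserted break/end tokens.
# token_ids: per-image tensors of placeholder (-100) and break/end tokens to splice into the token stream.
# lengths: number of patches contributed by each image.

At read time, `LanguageModelReader` normalizes the stored uint8 patches with `ImageNormalizationConfig` (CLIP mean/std by default) before returning the document.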