1 change: 1 addition & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -28,6 +28,7 @@ class GPTSamplingParameters(SamplingParameters):
# TODO: ====== Get these to memmap dataset (currently ignored) ======
use_loss_masking_spans: bool = False
use_preference_loss_spans: bool = False
use_images: bool = False


@dataclasses.dataclass(kw_only=True)
14 changes: 14 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -5,6 +5,7 @@

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.data.preparator.config import DatasetPreparatorConfig
from fast_llm.data.preprocessing.image_patch import ImagePatchConfig
from fast_llm.data.preprocessing.tokenizer import TokenizerConfig
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.runnable import RunnableConfig
@@ -35,6 +36,10 @@ class LanguageModelSourceConfig(Config):
rejected_span: None | str = Field(
default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional
)
images: None | str = Field(default=None, desc="Field containing images", hint=FieldHint.optional)
image_positions: None | str = Field(
default=None, desc="Field containing image positions in the text.", hint=FieldHint.optional
)

@functools.cached_property
def columns(self) -> list[str]:
@@ -54,6 +59,11 @@ def has_preference_spans(self) -> bool:
Assert.eq(self.chosen_span is None, self.rejected_span is None)
return self.chosen_span is not None

@functools.cached_property
def has_images(self) -> bool:
Assert.eq(self.images is None, self.image_positions is None)
return self.images is not None

def _validate(self):
super()._validate()
if self.has_preference_spans and self.has_loss_masking_span:
@@ -177,6 +187,10 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
desc="Configuration for the tokenizer.",
hint=FieldHint.feature,
)
image_patches: ImagePatchConfig = Field(
desc="Configuration for the image patches, if enabled.",
hint=FieldHint.feature,
)
splits: dict[str, float] | None = Field(
default=None,
desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios."
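As a minimal sketch of how the new options might be laid out in a preparator config (shown as a plain Python dict for illustration; the "source_schema" key and the surrounding nesting are assumptions inferred from the fields above, not confirmed by this diff, and the token ids are placeholders):

# Hypothetical layout; "source_schema" is an assumed key name, not taken from Fast-LLM docs.
preparator_config = {
    "source_schema": {
        "images": "images",                    # dataset column holding the raw image bytes
        "image_positions": "image_positions",  # dataset column with each image's character offset in the text
    },
    "image_patches": {
        "height": 16,
        "width": 16,
        "max_image_height": 1024,
        "max_image_width": 1024,
        "image_break_token": 32001,  # placeholder id appended after each row of patches
        "image_end_token": 32002,    # placeholder id appended after the last patch
    },
}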
53 changes: 51 additions & 2 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -30,6 +30,7 @@
from fast_llm.data.preprocessing.tokenizer import Tokenizer
from fast_llm.data.sample.abstract import MemmapIndexDatasetReaderConfig
from fast_llm.data.sample.language_model import LanguageModelSample, LanguageModelWriter
from fast_llm.data.sample.patch import PatchSample
from fast_llm.data.sample.range import RangeSample
from fast_llm.data.sample.token import TokenSample
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
@@ -43,6 +44,7 @@ class SpanType(enum.StrEnum):
loss_masking = "loss_masking"
chosen = "chosen"
rejected = "rejected"
image = "image"


class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
@@ -231,6 +233,30 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
text = full_chosen_text + full_rejected_text
all_spans.extend(chosen_spans + rejected_span)

if self._source_schema.has_images:
# Get the images and positions, sorted by position.
images, image_positions = (
zip(
*sorted(
zip(
sample[self._source_schema.images],
sample[self._source_schema.image_positions],
strict=True,
),
key=lambda x: x[1],
)
)
if len(sample[self._source_schema.images]) > 0
else ([], [])
)
# Get the image patches and associated data.
image_patches, image_position_ids, image_token_maps, image_token_ids, patch_counts = (
self._config.image_patches.get_patches(images, self._data_type)
)
patch_count_cumsum = padded_cumsum(patch_counts).tolist()
# Add an empty "span" at each image position so we know where to insert the image tokens in the tokenized sequence.
all_spans.extend([(SpanType.image, (position, position)) for position in image_positions])

# Sort the spans by location (begin), keeping track of their type.
# Note: overlapping spans are not supported (explicit assertion in the tokenizer).
span_types, spans = zip(*_sort_spans(all_spans)) if all_spans else ([], [])
@@ -241,8 +267,26 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:

# Gather token spans by type.
token_spans_by_type = collections.defaultdict(list)
for span_type, token_span in zip(span_types, token_spans, strict=True):
token_spans_by_type[span_type].append(token_span)
if self._source_schema.has_images:
# Insert the image token ids in the token sequence and shift the spans accordingly.
tokens_shift = 0
image_index = 0
for span_type, (begin, end) in zip(span_types, token_spans, strict=True):
# Account for the tokens already inserted.
begin = begin + tokens_shift
end = end + tokens_shift
if span_type == SpanType.image:
# Shift the token map to the image location.
image_token_maps[patch_count_cumsum[image_index] : patch_count_cumsum[image_index + 1]] += begin
# Insert the placeholder and image break tokens.
tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
tokens_shift += len(image_token_ids[image_index])
image_index += 1
else:
token_spans_by_type[span_type].append((begin, end))
else:
for span_type, token_span in zip(span_types, token_spans, strict=True):
token_spans_by_type[span_type].append(token_span)

sample_size = len(tokens)

@@ -264,6 +308,11 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample:
if self._source_schema.has_preference_spans
else None
),
(
PatchSample(image_patches, image_token_maps, image_position_ids, sample_size, patch_counts)
if self._source_schema.has_images
else None
),
)

def generate_config_yaml_for_sharded_dst(
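For intuition, here is a self-contained sketch (plain PyTorch, not the Fast-LLM API) of the insertion-and-shift loop added to `_prepare_sample` above: one image's placeholder token ids are spliced into the token sequence at its span position, and every later span is shifted by the number of inserted tokens. The values are arbitrary stand-ins.

import torch

tokens = torch.arange(10)                               # stand-in for the tokenized document
image_token_ids = [torch.tensor([-100, -100, 32001])]   # two patch placeholders + a break token for one image
spans = [("image", (4, 4)), ("loss_masking", (6, 9))]   # already sorted by begin

tokens_shift = 0
image_index = 0
shifted_spans = []
for span_type, (begin, end) in spans:
    # Account for the tokens already inserted before this span.
    begin, end = begin + tokens_shift, end + tokens_shift
    if span_type == "image":
        # Splice the image tokens in at `begin`; everything after moves right.
        tokens = torch.cat([tokens[:begin], image_token_ids[image_index], tokens[begin:]])
        tokens_shift += len(image_token_ids[image_index])
        image_index += 1
    else:
        shifted_spans.append((span_type, (begin, end)))

assert len(tokens) == 13
assert shifted_spans == [("loss_masking", (9, 12))]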
185 changes: 185 additions & 0 deletions fast_llm/data/preprocessing/image_patch.py
@@ -0,0 +1,185 @@
import functools
import io
import math
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert, div

if typing.TYPE_CHECKING:
import torch


@config_class()
class ImagePatchConfig(Config):
"""
Configuration for splitting images into fixed-size patches.
Used by the memmap dataset preparator to turn raw images into patch tensors and placeholder tokens.
"""

height: int = Field(
default=16,
desc="Height of the image patches, in pixels.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
width: int = Field(
default=16,
desc="Height of the image patches, in pixels.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
max_image_height: int = Field(
default=1024,
desc="Maximum height of the complete image, in pixels."
"If the original image is larger than this, it will be resized to this height.",
hint=FieldHint.optional,
)
max_image_width: int = Field(
default=1024,
desc="Maximum width of the complete image, in pixels."
"If the original image is larger than this, it will be resized to this width.",
hint=FieldHint.optional,
)
image_break_token: int | None = Field(
default=None,
desc="Add this token at the end of each row of image patches.",
hint=FieldHint.optional,
)
image_end_token: int | None = Field(
default=None,
desc="Add this token after the last patch of each image."
"If `image_break_token` is also defined, only `image_end_token` is added after the last row.",
hint=FieldHint.optional,
)

@property
def num_channels(self) -> int:
# assume 3 channels (RGB) for all images
return 3

@functools.cached_property
def max_patches_height(self) -> int:
return div(self.max_image_height, self.height)

@functools.cached_property
def max_patches_width(self) -> int:
return div(self.max_image_width, self.width)

def _validate(self):
super()._validate()
Assert.gt(self.max_patches_height, 0)
Assert.gt(self.max_patches_width, 0)

def get_patches(
self, images: list[bytes], token_data_type: DataType = DataType.int64
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["torch.Tensor"], list[int]]:
import torch

if len(images) > 0:
image_patches, image_positions, image_token_maps, image_token_ids = zip(
*(self._get_patches(image, token_data_type) for image in images)
)
return (
torch.cat(image_patches),
torch.cat(image_positions),
torch.cat(image_token_maps),
image_token_ids,
[len(position_ids) for position_ids in image_positions],
)
else:
# Return empty tensors of appropriate shapes and data types so we can concatenate with other documents.
return (
torch.empty(0, self.num_channels, self.height, self.width, dtype=torch.uint8),
torch.empty(0, 2, dtype=torch.int64),
torch.empty(0, dtype=torch.int64),
[],
[0],
)

def _get_patches(
self, image_bytes: bytes, token_data_type: DataType = DataType.int64
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
import numpy as np
import PIL.Image
import torch

with PIL.Image.open(io.BytesIO(image_bytes)) as image:
if image.mode != "RGB":
# Convert all images to RGB
image = image.convert("RGB")
image = torch.tensor(np.array(image)).permute(2, 0, 1) # HWC to CHW
Assert.eq(image.dtype, torch.uint8)

# Resize to a multiple of patch size smaller or equal to max size.
image = self._resize(image)
num_patches_height = div(image.size(1), self.height)
num_patches_width = div(image.size(2), self.width)
# Convert to patches. (`torch.nn.functional.unfold` not supported for uint8.)
patches = (
image.view(self.num_channels, num_patches_height, self.height, num_patches_width, self.width)
.permute(1, 3, 0, 2, 4)  # -> (num_patches_height, num_patches_width, channels, height, width): row-major, matching `positions` and `token_map` below.
.flatten(0, 1)
)

positions = torch.stack(
[
torch.arange(num_patches_height).repeat_interleave(num_patches_width),
torch.arange(num_patches_width).repeat(num_patches_height),
],
1,
)

token_map = torch.arange(0, num_patches_width * num_patches_height, dtype=torch.int64)
if self.image_break_token is None:
token_ids = [-100] * (num_patches_width * num_patches_height)
if self.image_end_token is not None:
token_ids.append(self.image_end_token)
else:
token_ids = ([-100] * num_patches_width + [self.image_break_token]) * num_patches_height
token_map += torch.arange(num_patches_height).repeat_interleave(num_patches_width)
if self.image_end_token is not None:
token_ids[-1] = self.image_end_token

return patches, positions, token_map, torch.tensor(token_ids, dtype=token_data_type.torch)

def _resize(self, image: "torch.Tensor") -> "torch.Tensor":
"""
Resize an image so that both dimensions are multiples of the patch size.
If the image is larger than the maximum dimensions, it is first scaled down (preserving the aspect ratio) to fit within them.
Each dimension is then rounded up to the nearest multiple of the patch size.
"""
import torchvision.transforms.v2 as torchvision_transforms

target_height, target_width = image.shape[1:]
ratio = max(target_height / self.max_image_height, target_width / self.max_image_width, 1)
target_height = self.height * math.ceil(target_height / self.height / ratio)
target_width = self.width * math.ceil(target_width / self.width / ratio)

# Cap the resizing to half of the current size as a workaround for large images
# See pytorch issue: https://github.com/pytorch/pytorch/issues/103589
while max(image.size(1) / target_height, image.size(2) / target_width) > 2:
image = torchvision_transforms.functional.resize(
image,
size=(math.ceil(image.size(1) / 2), math.ceil(image.size(2) / 2)),
interpolation=torchvision_transforms.InterpolationMode.BICUBIC,
)

# TODO: options for interpolation mode?
return torchvision_transforms.functional.resize(
image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC
)


@config_class()
class ImageNormalizationConfig(Config):
scale: float = Field(default=255.0)
# Default mean and std values from OpenAI CLIP.
mean: tuple[float, float, float] = Field(default=(0.48145466, 0.4578275, 0.40821073))
std: tuple[float, float, float] = Field(default=(0.26862954, 0.26130258, 0.27577711))

def normalize(self, image: "torch.Tensor") -> "torch.Tensor":
import torchvision.transforms.v2 as torchvision_transforms

return torchvision_transforms.functional.normalize(image / self.scale, list(self.mean), list(self.std))
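As a closing illustration, this standalone sketch (plain PyTorch, arbitrary sizes and token ids) reproduces the patch split, the row-major patch positions, the `image_break_token` layout from `_get_patches`, and the CLIP-style normalization from `ImageNormalizationConfig`.

import torch

# A fake 3x32x48 RGB image split into 16x16 patches: 2 rows x 3 columns of patches.
channels, patch_h, patch_w = 3, 16, 16
image = torch.randint(0, 256, (3, 32, 48), dtype=torch.uint8)
n_h, n_w = image.size(1) // patch_h, image.size(2) // patch_w

patches = (
    image.view(channels, n_h, patch_h, n_w, patch_w)
    .permute(1, 3, 0, 2, 4)  # -> (n_h, n_w, channels, patch_h, patch_w), row-major patch order
    .flatten(0, 1)           # -> (n_h * n_w, channels, patch_h, patch_w)
)
assert torch.equal(patches[1], image[:, :16, 16:32])  # patch 1 = row 0, column 1

# Row-major (height, width) index of each patch, matching the flatten order above.
positions = torch.stack([torch.arange(n_h).repeat_interleave(n_w), torch.arange(n_w).repeat(n_h)], 1)

# With an image_break_token, each row of n_w placeholders (-100) is followed by one break token,
# so patch i maps to token slot i plus one extra slot for every completed row above it.
image_break_token = 32001  # arbitrary id for illustration
token_ids = ([-100] * n_w + [image_break_token]) * n_h
token_map = torch.arange(n_h * n_w) + torch.arange(n_h).repeat_interleave(n_w)
assert token_map.tolist() == [0, 1, 2, 4, 5, 6] and len(token_ids) == 8

# CLIP-style normalization with the ImageNormalizationConfig defaults.
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(3, 1, 1)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(3, 1, 1)
normalized = (patches.float() / 255.0 - mean) / std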