Skip to content

Commit 1450bf2

Browse files
authored
feat(transformers) add Idefics series models (#1159)
* add idefics, idefics2 * added infer example; passed UT * add processors; add copyright * synchronize PR #1127 * fix bug; support fa * delete example scripts; example in model comments * linting
1 parent 4e9de5c commit 1450bf2

File tree

18 files changed

+5282
-70
lines changed

18 files changed

+5282
-70
lines changed

mindone/transformers/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
AutoModelForCausalLM,
5757
AutoModelForImageTextToText,
5858
AutoModelForMaskedLM,
59+
AutoModelForVision2Seq,
5960
AutoProcessor,
6061
)
6162
from .models.bart import (
@@ -251,6 +252,14 @@
251252
HieraPreTrainedModel,
252253
)
253254
from .models.hubert import HubertForCTC, HubertForSequenceClassification, HubertModel, HubertPreTrainedModel
255+
from .models.idefics import (
256+
IdeficsForVisionText2Text,
257+
IdeficsImageProcessor,
258+
IdeficsModel,
259+
IdeficsPreTrainedModel,
260+
IdeficsProcessor,
261+
)
262+
from .models.idefics2 import Idefics2ForConditionalGeneration, Idefics2Model, Idefics2PreTrainedModel
254263
from .models.idefics3 import (
255264
Idefics3ForConditionalGeneration,
256265
Idefics3Model,

mindone/transformers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
granitemoeshared,
4848
hiera,
4949
hubert,
50+
idefics,
51+
idefics2,
5052
idefics3,
5153
ijepa,
5254
imagegpt,

mindone/transformers/models/auto/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,11 @@
1717
from .configuration_auto import AutoConfig
1818
from .feature_extraction_auto import AutoFeatureExtractor
1919
from .image_processing_auto import AutoImageProcessor
20-
from .modeling_auto import AutoModel, AutoModelForCausalLM, AutoModelForImageTextToText, AutoModelForMaskedLM
20+
from .modeling_auto import (
21+
AutoModel,
22+
AutoModelForCausalLM,
23+
AutoModelForImageTextToText,
24+
AutoModelForMaskedLM,
25+
AutoModelForVision2Seq,
26+
)
2127
from .processing_auto import AutoProcessor

mindone/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
[
5252
("blip", "BlipProcessor"),
5353
("chameleon", "ChameleonProcessor"),
54+
("idefics", "IdeficsProcessor"),
5455
("llava_next", "LlavaNextProcessor"),
5556
("llava_next_video", "LlavaNextVideoProcessor"),
5657
("llava_onevision", "LlavaOnevisionProcessor"),
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# This code is adapted from https://github.com/huggingface/transformers
4+
# with modifications to run transformers on mindspore.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
from .image_processing_idefics import *
19+
from .modeling_idefics import IdeficsForVisionText2Text, IdeficsModel, IdeficsPreTrainedModel
20+
from .processing_idefics import *
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# coding=utf-8
2+
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
3+
#
4+
# This code is adapted from https://github.com/huggingface/transformers
5+
# with modifications to run transformers on mindspore.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
"""Image processor class for Idefics."""
19+
20+
from typing import Callable, Optional, Union
21+
22+
from PIL import Image
23+
24+
from mindspore import mint
25+
26+
from ...image_processing_utils import BaseImageProcessor, BatchFeature
27+
from ...image_transforms import resize, to_channel_dimension_format
28+
from ...image_utils import (
29+
ChannelDimension,
30+
ImageInput,
31+
PILImageResampling,
32+
make_list_of_images,
33+
to_numpy_array,
34+
valid_images,
35+
)
36+
from ...utils import TensorType
37+
38+
IDEFICS_STANDARD_MEAN = [0.48145466, 0.4578275, 0.40821073]
39+
IDEFICS_STANDARD_STD = [0.26862954, 0.26130258, 0.27577711]
40+
41+
42+
def convert_to_rgb(image):
43+
# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
44+
# for transparent images. The call to `alpha_composite` handles this case
45+
if image.mode == "RGB":
46+
return image
47+
48+
image_rgba = image.convert("RGBA")
49+
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
50+
alpha_composite = Image.alpha_composite(background, image_rgba)
51+
alpha_composite = alpha_composite.convert("RGB")
52+
return alpha_composite
53+
54+
55+
class IdeficsImageProcessor(BaseImageProcessor):
56+
r"""
57+
Constructs a Idefics image processor.
58+
59+
Args:
60+
image_size (`int`, *optional*, defaults to 224):
61+
Resize to image size
62+
image_mean (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
63+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
64+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
65+
overridden by the `image_mean` parameter in the `preprocess` method.
66+
image_std (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
67+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
68+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
69+
Can be overridden by the `image_std` parameter in the `preprocess` method.
70+
image_num_channels (`int`, *optional*, defaults to 3):
71+
Number of image channels.
72+
do_rescale (`bool`, *optional*, defaults to `True`):
73+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
74+
the `preprocess` method.
75+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
76+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
77+
method.
78+
"""
79+
80+
model_input_names = ["pixel_values"]
81+
82+
def __init__(
83+
self,
84+
image_size: int = 224,
85+
image_mean: Optional[Union[float, list[float]]] = None,
86+
image_std: Optional[Union[float, list[float]]] = None,
87+
image_num_channels: Optional[int] = 3,
88+
do_rescale: bool = True,
89+
rescale_factor: Union[int, float] = 1 / 255,
90+
**kwargs,
91+
) -> None:
92+
super().__init__(**kwargs)
93+
94+
self.image_size = image_size
95+
self.image_num_channels = image_num_channels
96+
self.image_mean = image_mean if image_mean is not None else IDEFICS_STANDARD_MEAN
97+
self.image_std = image_std if image_std is not None else IDEFICS_STANDARD_STD
98+
self.do_rescale = do_rescale
99+
self.rescale_factor = rescale_factor
100+
101+
def preprocess(
102+
self,
103+
images: ImageInput,
104+
image_num_channels: Optional[int] = 3,
105+
image_size: Optional[dict[str, int]] = None,
106+
image_mean: Optional[Union[float, list[float]]] = None,
107+
image_std: Optional[Union[float, list[float]]] = None,
108+
transform: Optional[Callable] = None,
109+
do_rescale: Optional[bool] = None,
110+
rescale_factor: Optional[float] = None,
111+
return_tensors: Optional[Union[str, TensorType]] = TensorType.MINDSPORE,
112+
**kwargs,
113+
) -> TensorType:
114+
"""
115+
Preprocess a batch of images.
116+
117+
Args:
118+
images (`ImageInput`):
119+
A list of images to preprocess.
120+
image_size (`int`, *optional*, defaults to `self.image_size`):
121+
Resize to image size
122+
image_num_channels (`int`, *optional*, defaults to `self.image_num_channels`):
123+
Number of image channels.
124+
image_mean (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_MEAN`):
125+
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
126+
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can
127+
be overridden by the `image_mean` parameter in the `preprocess` method.
128+
image_std (`float` or `list[float]`, *optional*, defaults to `IDEFICS_STANDARD_STD`):
129+
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
130+
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess`
131+
method. Can be overridden by the `image_std` parameter in the `preprocess` method.
132+
transform (`Callable`, *optional*, defaults to `None`):
133+
A custom transform function that accepts a single image can be passed for training. For example,
134+
`transforms.Compose` can be used to compose multiple transforms. If `None` - an inference mode is
135+
assumed - and then a preset of inference-specific transforms will be applied to the images
136+
do_rescale (`bool`, *optional*, defaults to `True`):
137+
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
138+
the `preprocess` method.
139+
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
140+
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
141+
method.
142+
143+
Returns:
144+
a MindSpore tensor of the processed images
145+
146+
"""
147+
image_size = image_size if image_size is not None else self.image_size
148+
image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
149+
image_mean = image_mean if image_mean is not None else self.image_mean
150+
image_std = image_std if image_std is not None else self.image_std
151+
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
152+
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
153+
size = (image_size, image_size)
154+
155+
if isinstance(images, list) and len(images) == 0:
156+
return []
157+
158+
images = make_list_of_images(images)
159+
160+
if not valid_images(images):
161+
raise ValueError(
162+
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
163+
"ms.Tensor, tf.Tensor or jax.ndarray."
164+
)
165+
166+
# For training a user needs to pass their own set of transforms as a Callable.
167+
# For reference this is what was used in the original IDEFICS training:
168+
# transform = transforms.Compose([
169+
# convert_to_rgb,
170+
# vision.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=vision.Inter.BICUBIC),
171+
# vision.ToTensor(),
172+
# vision.Normalize(mean=image_mean, std=image_std),
173+
# ])
174+
if transform is not None:
175+
images = [transform(x) for x in images]
176+
return mint.stack(images)
177+
178+
# for inference we do the exact transforms that were used to train IDEFICS
179+
images = [convert_to_rgb(x) for x in images]
180+
# further transforms expect numpy arrays
181+
images = [to_numpy_array(x) for x in images]
182+
images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
183+
images = [self.rescale(image=image, scale=rescale_factor) for image in images]
184+
images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
185+
images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
186+
images = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)["pixel_values"]
187+
188+
return images
189+
190+
191+
__all__ = ["IdeficsImageProcessor"]

0 commit comments

Comments
 (0)