Skip to content

Commit 48a81b1

Browse files
authored
[FEATURE] Add Chameleon generate example (#1019)
* update UT * update model * fix UT * add to auto model * add generate.py * fix sampling * align with v4.50 and support batch inference * remove generate examples and add preprocess * add copyright
1 parent 90e8a09 commit 48a81b1

27 files changed

+821
-103
lines changed

mindone/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,10 @@
142142
)
143143
from .models.chameleon import (
144144
ChameleonForConditionalGeneration,
145+
ChameleonImageProcessor,
145146
ChameleonModel,
146147
ChameleonPreTrainedModel,
148+
ChameleonProcessor,
147149
ChameleonVQVAE,
148150
)
149151
from .models.clap import (

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
blip_2,
3131
camembert,
3232
canine,
33+
chameleon,
3334
clap,
3435
clip,
3536
convbert,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
("bit", "BitConfig"),
4444
("blip", "BlipConfig"),
4545
("blip-2", "Blip2Config"),
46+
("chameleon", "ChameleonConfig"),
4647
("camembert", "CamembertConfig"),
4748
("convbert", "ConvBertConfig"),
4849
("clip", "CLIPConfig"),

mindone/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
("beit", ("BeitImageProcessor",)),
5151
("blip", ("BlipImageProcessor",)),
5252
("blip-2", ("BlipImageProcessor",)),
53+
("chameleon", ("ChameleonImageProcessor",)),
5354
("clip", ("CLIPImageProcessor",)),
5455
("dpt", ("DPTImageProcessor",)),
5556
("llava_next", ("LlavaNextImageProcessor",)),

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@
283283
("aria", "AriaForConditionalGeneration"),
284284
("blip", "BlipForConditionalGeneration"),
285285
("blip-2", "Blip2ForConditionalGeneration"),
286+
("chameleon", "ChameleonForConditionalGeneration"),
286287
("gemma3", "Gemma3ForConditionalGeneration"),
287288
("chameleon", "ChameleonForConditionalGeneration"),
288289
("idefics", "IdeficsForVisionText2Text"),

mindone/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
PROCESSOR_MAPPING_NAMES = OrderedDict(
5151
[
5252
("blip", "BlipProcessor"),
53+
("chameleon", "ChameleonProcessor"),
5354
("llava_next", "LlavaNextProcessor"),
5455
("llava_next_video", "LlavaNextVideoProcessor"),
5556
("llava_onevision", "LlavaOnevisionProcessor"),

mindone/transformers/models/chameleon/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,6 @@
1414
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
17+
from .image_processing_chameleon import *
1718
from .modeling_chameleon import *
19+
from .processing_chameleon import *

mindone/transformers/models/chameleon/image_processing_chameleon.py

Lines changed: 334 additions & 0 deletions
Large diffs are not rendered by default.

mindone/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 172 additions & 41 deletions
Large diffs are not rendered by default.
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# coding=utf-8
2+
# Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
3+
#
4+
# This code is adapted from https://github.com/huggingface/transformers
5+
# with modifications to run transformers on mindspore.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
"""
19+
Processor class for Chameleon.
20+
"""
21+
from typing import List, Optional, Union
22+
23+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
24+
25+
import mindspore as ms
26+
27+
from ...feature_extraction_utils import BatchFeature
28+
from ...image_utils import ImageInput
29+
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
30+
31+
32+
class ChameleonTextKwargs(TextKwargs, total=False):
33+
return_for_text_completion: bool
34+
35+
36+
class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
37+
text_kwargs: ChameleonTextKwargs
38+
_defaults = {
39+
"text_kwargs": {
40+
"padding": False,
41+
"return_for_text_completion": False,
42+
},
43+
"common_kwargs": {
44+
"return_tensors": "ms",
45+
},
46+
}
47+
48+
49+
class ChameleonProcessor(ProcessorMixin):
50+
r"""
51+
Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
52+
processor.
53+
54+
[`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
55+
See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
56+
57+
Args:
58+
image_processor ([`ChameleonImageProcessor`]):
59+
The image processor is a required input.
60+
tokenizer ([`LlamaTokenizerFast`]):
61+
The tokenizer is a required input.
62+
image_seq_length (`int`, *optional*, defaults to 1024):
63+
Sequence length of one image embedding.
64+
image_token (`str`, *optional*, defaults to `"<image>"`):
65+
The special token used to indicate image in the text.
66+
"""
67+
68+
attributes = ["image_processor", "tokenizer"]
69+
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
70+
valid_kwargs = ["image_seq_length", "image_token"]
71+
image_processor_class = "ChameleonImageProcessor"
72+
73+
def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
74+
self.image_seq_length = image_seq_length
75+
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
76+
self.image_start_token = (
77+
tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
78+
) # fixed tokens for start and end, so can hardcode
79+
self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
80+
81+
super().__init__(image_processor, tokenizer)
82+
83+
def __call__(
84+
self,
85+
images: Optional[ImageInput] = None,
86+
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
87+
audio=None,
88+
videos=None,
89+
**kwargs: Unpack[ChameleonProcessorKwargs],
90+
) -> BatchFeature:
91+
"""
92+
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
93+
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
94+
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
95+
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
96+
of the above two methods for more information.
97+
98+
Args:
99+
images (`PIL.Image.Image`, `np.ndarray`, `mindspore.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[mindspore.Tensor]`):
100+
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
101+
tensor. Both channels-first and channels-last formats are supported.
102+
text (`str`, `List[str]`, `List[List[str]]`):
103+
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
104+
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
105+
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
106+
return_tensors (`str` or [`~utils.TensorType`], *optional*):
107+
If set, will return tensors of a particular framework. Acceptable values are:
108+
109+
- `'tf'`: Return TensorFlow `tf.constant` objects.
110+
- `'ms'`: Return PyTorch `mindspore.Tensor` objects.
111+
- `'np'`: Return NumPy `np.ndarray` objects.
112+
- `'jax'`: Return JAX `jnp.ndarray` objects.
113+
114+
Returns:
115+
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
116+
117+
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
118+
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
119+
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
120+
`None`).
121+
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
122+
"""
123+
# check if images and text inputs are reversed for BC
124+
images, text = _validate_images_text_input_order(images, text)
125+
if isinstance(text, str):
126+
text = [text]
127+
elif not isinstance(text, list) and not isinstance(text[0], str):
128+
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
129+
if text is None and images is None:
130+
raise ValueError("You must provide either text or images")
131+
132+
output_kwargs = self._merge_kwargs(
133+
ChameleonProcessorKwargs,
134+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
135+
**kwargs,
136+
)
137+
return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
138+
139+
# Replace the image token with the expanded image token sequence
140+
prompt_strings = []
141+
one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
142+
for sample in text:
143+
sample = sample.replace(self.image_token, one_img_tokens)
144+
if not return_for_text_completion:
145+
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
146+
prompt_strings.append(sample)
147+
148+
output_kwargs["text_kwargs"].pop("return_tensors", None)
149+
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors="np")
150+
for k, v in data.items():
151+
data[k] = ms.tensor(v)
152+
153+
if images is not None:
154+
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
155+
156+
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
157+
158+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
159+
def batch_decode(self, *args, **kwargs):
160+
"""
161+
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
162+
refer to the docstring of this method for more information.
163+
"""
164+
return self.tokenizer.batch_decode(*args, **kwargs)
165+
166+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
167+
def decode(self, *args, **kwargs):
168+
"""
169+
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
170+
the docstring of this method for more information.
171+
"""
172+
return self.tokenizer.decode(*args, **kwargs)
173+
174+
@property
175+
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
176+
def model_input_names(self):
177+
tokenizer_input_names = self.tokenizer.model_input_names
178+
image_processor_input_names = self.image_processor.model_input_names
179+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
180+
181+
182+
__all__ = ["ChameleonProcessor"]

0 commit comments

Comments
 (0)