class MultimodalDataset(HuggingfaceDataset):
    """Image+text dataset that processes examples lazily with a HF processor.

    Loads a dataset via ``HuggingfaceDataset.load`` and attaches an on-access
    transform that feeds the image column (a file path or an already-decoded
    PIL image) and the text column through ``AutoProcessor``, returning
    padded PyTorch tensors.
    """

    def __init__(self, processor_name_or_path, image_column="image", text_column="text", *args, **kwargs):
        """
        Args:
            processor_name_or_path: model id or local path passed to
                ``AutoProcessor.from_pretrained``.
            image_column: dataset column holding images (path or PIL image).
            text_column: dataset column holding the paired text.
            *args, **kwargs: forwarded to ``HuggingfaceDataset.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.processor_name_or_path = processor_name_or_path
        self.image_column = image_column
        self.text_column = text_column
        # Presumably consumed by the parent's load path — TODO confirm its
        # exact semantics in HuggingfaceDataset (not visible in this chunk).
        self.inplace_load = False

    def load(self, file_path):
        """Load the dataset at ``file_path`` and attach the processor transform."""
        dataset = super().load(file_path)
        return self._post_process(dataset)

    def _post_process(self, dataset):
        """Attach a lazy image+text -> tensor transform to ``dataset``."""
        # NOTE(review): assumes ``trust_remote_code`` is set by the parent
        # HuggingfaceDataset — confirm the attribute exists on all code paths.
        processor = AutoProcessor.from_pretrained(
            self.processor_name_or_path, trust_remote_code=self.trust_remote_code
        )

        def transform(examples):
            import os  # local import keeps this addition self-contained

            # Accept on-disk paths (str or os.PathLike) as well as
            # already-decoded PIL images; normalize everything to RGB.
            images = [
                Image.open(x).convert("RGB")
                if isinstance(x, (str, os.PathLike))
                else x.convert("RGB")
                for x in examples[self.image_column]
            ]
            texts = examples[self.text_column]
            # Batched call: pad the text so tensors in the batch line up.
            return processor(text=texts, images=images, return_tensors="pt", padding=True)

        # set_transform applies on __getitem__ access, avoiding an upfront map().
        dataset.set_transform(transform)
        return dataset
class HFAutoModel:
    """Lazy factory for ``transformers.AutoModel``.

    Stores the ``from_pretrained`` arguments at construction time and only
    instantiates the (potentially large) model when ``load`` is called,
    mirroring ``HFAutoModelForCausalLM`` in this module.
    """

    def __init__(self, pretrained_model_name_or_path, *model_args, **kwargs) -> None:
        """
        Args:
            pretrained_model_name_or_path: model id or local path.
            *model_args, **kwargs: forwarded to ``AutoModel.from_pretrained``.
                A string ``torch_dtype`` other than ``"auto"`` is eagerly
                resolved to the corresponding ``torch.dtype``.
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_args = model_args
        self.kwargs = kwargs
        dtype = self.kwargs.get("torch_dtype")
        # Resolve string dtype names ("float16", ...) to real torch dtypes.
        # "auto" is passed through, and a value that is already a
        # torch.dtype is left untouched instead of crashing getattr().
        if isinstance(dtype, str) and dtype != "auto":
            # Local import: the module-level ``import torch`` is commented
            # out in this file, so resolving here avoids a NameError.
            import torch

            self.kwargs["torch_dtype"] = getattr(torch, dtype)

    def load(self):
        """Instantiate and return the model via ``AutoModel.from_pretrained``."""
        model = AutoModel.from_pretrained(
            self.pretrained_model_name_or_path, *self.model_args, **self.kwargs
        )
        return model