Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion python/fate_llm/dataset/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from typing import Optional, Union, Sequence, Mapping, Dict

from datasets import load_dataset, Features, Split, DownloadConfig, DownloadMode, VerificationMode, Version, load_from_disk
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoProcessor
from PIL import Image

from fate.ml.nn.dataset.base import Dataset

Expand Down Expand Up @@ -213,3 +214,28 @@ def tokenize_function(examples):

dataset = dataset.map(tokenize_function, batched=True)
return dataset


class MultimodalDataset(HuggingfaceDataset):
    """Image+text dataset that lazily runs a Hugging Face processor on access.

    Loads an underlying Hugging Face dataset via the parent class, then
    registers an on-the-fly transform (``datasets.Dataset.set_transform``)
    that feeds the image and text columns through ``AutoProcessor`` so each
    batch comes back as padded ``pt`` tensors.
    """

    def __init__(self, processor_name_or_path, image_column="image", text_column="text", *args, **kwargs):
        """
        Args:
            processor_name_or_path: model id or local path passed to
                ``AutoProcessor.from_pretrained``.
            image_column: dataset column holding images (file paths or PIL images).
            text_column: dataset column holding the paired text.
            *args, **kwargs: forwarded to ``HuggingfaceDataset.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.processor_name_or_path = processor_name_or_path
        self.image_column = image_column
        self.text_column = text_column
        # Force the parent to hand back the loaded dataset instead of keeping
        # it internally, so load() can post-process the return value.
        self.inplace_load = False

    def load(self, file_path):
        """Load the dataset from *file_path* and attach the processor transform."""
        loaded = super().load(file_path)
        return self._post_process(loaded)

    def _post_process(self, dataset):
        """Attach a batch transform that runs images/texts through the processor."""
        # NOTE(review): self.trust_remote_code is assumed to be set by the
        # parent HuggingfaceDataset — confirm it exists on all code paths.
        processor = AutoProcessor.from_pretrained(
            self.processor_name_or_path, trust_remote_code=self.trust_remote_code
        )

        def _as_rgb(item):
            # Strings are treated as file paths; anything else is assumed to
            # already be a PIL image exposing .convert().
            img = Image.open(item) if isinstance(item, str) else item
            return img.convert("RGB")

        def transform(examples):
            return processor(
                text=examples[self.text_column],
                images=[_as_rgb(i) for i in examples[self.image_column]],
                return_tensors="pt",
                padding=True,
            )

        dataset.set_transform(transform)
        return dataset
19 changes: 18 additions & 1 deletion python/fate_llm/model_zoo/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
#
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoModel


class HFAutoModelForCausalLM:
Expand All @@ -32,3 +32,20 @@ def load(self):
self.pretrained_model_name_or_path, *self.model_args, **self.kwargs
)
return model


class HFAutoModel:
    """Lazy wrapper around ``transformers.AutoModel.from_pretrained``.

    Stores the constructor arguments and defers the (expensive) model
    download/instantiation until :meth:`load` is called.
    """

    def __init__(self, pretrained_model_name_or_path, *model_args, **kwargs) -> None:
        """
        Args:
            pretrained_model_name_or_path: model id or local path understood
                by ``AutoModel.from_pretrained``.
            *model_args, **kwargs: forwarded verbatim to ``from_pretrained``.
                A string ``torch_dtype`` (e.g. ``"float16"``) is resolved to
                the corresponding ``torch.dtype``; ``"auto"`` and real
                ``torch.dtype`` values are passed through unchanged.
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        self.model_args = model_args
        self.kwargs = kwargs
        dtype = self.kwargs.get("torch_dtype")
        # Only string names need resolution; the original unconditional
        # getattr(torch, dtype) crashed when callers passed a torch.dtype.
        if isinstance(dtype, str) and dtype != "auto":
            self.kwargs["torch_dtype"] = getattr(torch, dtype)

    def load(self):
        """Instantiate and return the pretrained model."""
        model = AutoModel.from_pretrained(
            self.pretrained_model_name_or_path, *self.model_args, **self.kwargs
        )
        return model