Skip to content

Commit e02ecd2

Browse files
committed
feat: data preprocessing pipeline
1 parent a73f3e8 commit e02ecd2

File tree

8 files changed

+1350
-252
lines changed

8 files changed

+1350
-252
lines changed

data-processing.py

Lines changed: 480 additions & 0 deletions
Large diffs are not rendered by default.

dataset_preprocessing.ipynb

Lines changed: 777 additions & 245 deletions
Large diffs are not rendered by default.

datasets/gcsfuse.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,6 @@ mkdir -p $MOUNT_PATH
4747
# Grain uses _PROCESS_MANAGEMENT_MAX_THREADS = 64 (https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)
4848
# Please make sure max-conns-per-host > grain_worker_count * _PROCESS_MANAGEMENT_MAX_THREADS
4949

50-
gcsfuse -o ro --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=100000 \
51-
--experimental-enable-json-read --kernel-list-cache-ttl-secs=-1 -o ro --config-file=$HOME/gcsfuse.yml \
52-
--log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"
50+
gcsfuse -o rw --implicit-dirs --http-client-timeout=5s --max-conns-per-host=0 --max-idle-conns-per-host=100000 \
51+
--experimental-enable-json-read --kernel-list-cache-ttl-secs=-1 -o rw --config-file=$HOME/gcsfuse.yml \
52+
--log-file=$HOME/gcsfuse.json "$DATASET_GCS_BUCKET" "$MOUNT_PATH"

flaxdiff/data/dataloaders.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,12 @@ def fallback_batch(batch, media_type="image"):
251251
else: # Default to image
252252
return image_collate
253253

254+
class CaptionDeletionTransform(pygrain.MapTransform):
    """Strips the raw "caption" string from an element.

    Applied after any caption-based filtering so that variable-length
    strings do not reach the batching stage.
    """

    def map(self, element):
        # pop() with a default removes the key when present and is a
        # no-op otherwise — same effect as the guarded `del`.
        element.pop("caption", None)
        return element
254260

255261
def get_dataset_grain(
256262
data_name="cc12m",
@@ -288,6 +294,7 @@ def get_dataset_grain(
288294
dataset = datasetMap[data_name]
289295
data_source = dataset["source"](dataset_source)
290296
augmenter = dataset["augmenter"](image_scale, method)
297+
filters = dataset.get("filter", None)(image_scale)
291298

292299
local_batch_size = batch_size // jax.process_count()
293300

@@ -310,8 +317,14 @@ def get_dataset_grain(
310317
def get_trainset():
311318
transformations = [
312319
augmenter(),
313-
pygrain.Batch(local_batch_size, drop_remainder=True),
314320
]
321+
322+
if filters:
323+
print("Adding filters to transformations")
324+
transformations.append(filters())
325+
326+
transformations.append(CaptionDeletionTransform())
327+
transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
315328

316329
loader = pygrain.DataLoader(
317330
data_source=data_source,

flaxdiff/data/dataset_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# ---------------------------------------------------------------------------------
99

1010
from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
11-
from .sources.images import data_source_combined_gcs, gcs_augmenters
11+
from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters
1212

1313
# Configure the following for your datasets
1414
datasetMap = {
@@ -23,6 +23,7 @@
2323
"laiona_coco": {
2424
"source": data_source_gcs('datasets/laion12m+mscoco'),
2525
"augmenter": gcs_augmenters,
26+
"filter": gcs_filters,
2627
},
2728
"aesthetic_coyo": {
2829
"source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),

flaxdiff/data/sources/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ def create_transform(self, **kwargs) -> Callable[[], pygrain.MapTransform]:
6262
"""
6363
pass
6464

65+
@abstractmethod
def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
    """Create a filter function for the data.

    Implementations return the FilterTransform *class* (not an instance);
    the caller instantiates it when assembling the loader pipeline.

    Args:
        **kwargs: Additional arguments for the filter
            (e.g. ``image_scale`` — see the GCS implementation).

    Returns:
        A callable that returns a pygrain.FilterTransform instance.
    """
    pass
76+
6577
@staticmethod
6678
def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
6779
"""Factory method to create a data augmenter of the specified type.

flaxdiff/data/sources/images.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ def map(self, element) -> Dict[str, jnp.array]:
168168
}
169169

170170
return TFDSTransform
171-
171+
172+
def create_filter(self, image_scale: int = 256):
    """Create a pass-through filter for TFDS datasets.

    Fixes two defects in the original: pygrain filter transforms must
    implement ``filter`` (the original defined ``map``, which grain never
    calls on a FilterTransform), and the class must be returned so callers
    can instantiate it (the original implicitly returned None).

    Args:
        image_scale: Unused here; kept for interface parity with the
            other augmenters' create_filter methods.

    Returns:
        A pygrain.FilterTransform subclass that keeps every element.
    """
    class FilterTransform(pygrain.FilterTransform):
        def filter(self, element) -> bool:
            # No filtering implemented for TFDS sources — keep everything.
            return True
    return FilterTransform
172176
"""
173177
Batch structure:
174178
{
@@ -237,7 +241,6 @@ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> An
237241
records_path) if 'array_record' in i]
238242
return pygrain.ArrayRecordDataSource(records)
239243

240-
241244
class ImageGCSAugmenter(DataAugmenter):
242245
"""Augmenter for GCS image datasets."""
243246

@@ -290,13 +293,60 @@ def map(self, element) -> Dict[str, jnp.array]:
290293
results = self.auto_tokenize(caption)
291294
return {
292295
"image": image,
296+
"caption": caption,
293297
"text": {
294298
"input_ids": results['input_ids'][0],
295299
"attention_mask": results['attention_mask'][0],
296300
}
297301
}
298302

299303
return GCSTransform
304+
305+
def create_filter(self, image_scale: int = 256, similarity_threshold: float = 0.25):
    """Create a CLIP image/caption similarity filter for GCS elements.

    The returned transform keeps an element only when the CLIP cosine
    similarity between its image and its caption reaches the threshold.
    Requires the augmenter's transform to have put both ``image`` and
    ``caption`` keys on the element.

    Improvements over the original: commented-out dead code removed,
    unused ``transformers`` imports dropped, embeddings computed under
    ``torch.no_grad()`` (the filter is inference-only; without it autograd
    state is built for every element), the 0.25 cut-off is now a
    parameter (default unchanged), and a plain ``bool`` is returned
    instead of a 0-d tensor.

    Args:
        image_scale: Stored on the transform; kept for interface parity
            with create_transform.
        similarity_threshold: Minimum cosine similarity to keep an
            element (default 0.25, matching the previous behavior).

    Returns:
        A pygrain.FilterTransform subclass; instantiate it in the pipeline.
    """
    import torch
    import torch.nn.functional as F

    class FilterTransform(pygrain.FilterTransform):
        """Filter transform for GCS data source."""

        def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
            super().__init__()
            self.image_scale = image_scale
            if model is None:
                # Lazy import: transformers is only needed when filtering
                # is actually enabled for a dataset.
                from transformers import AutoProcessor, CLIPModel
                model_name = "openai/clip-vit-base-patch32"
                model = CLIPModel.from_pretrained(model_name)
                processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
            self.method = method
            self.model = model
            self.processor = processor

        def filter(self, data: Dict[str, Any]) -> bool:
            """Keep the element iff CLIP similarity >= similarity_threshold."""
            inputs = self.processor(
                text=[data['caption']],
                images=[data['image']],
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            with torch.no_grad():
                image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
                text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
                similarity = F.cosine_similarity(image_embeds, text_embeds)
            return bool(similarity[0] >= similarity_threshold)

    return FilterTransform
300350

301351

302352
# ----------------------------------------------------------------------------------
@@ -333,3 +383,8 @@ def gcs_augmenters(image_scale, method):
333383
"""Legacy function for GCS augmenters."""
334384
augmenter = ImageGCSAugmenter()
335385
return augmenter.create_transform(image_scale=image_scale, method=method)
386+
387+
def gcs_filters(image_scale):
    """Legacy function for GCS Filters.

    Thin wrapper used by the datasetMap "filter" entries; delegates to
    ImageGCSAugmenter.create_filter.
    """
    return ImageGCSAugmenter().create_filter(image_scale=image_scale)

flaxdiff/data/sources/videos.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ def random_map(self, element, rng: np.random.Generator) -> Dict[str, jnp.array]:
216216

217217
return AudioVideoTransform
218218

219+
220+
def create_filter(self, image_scale: int = 256):
    """Create a pass-through filter for video datasets.

    Fixes two defects in the original: pygrain filter transforms must
    implement ``filter`` (the original defined ``map``, which grain never
    calls on a FilterTransform), and the class must be returned so callers
    can instantiate it (the original implicitly returned None).

    Args:
        image_scale: Unused here; kept for interface parity with the
            image augmenters' create_filter methods.

    Returns:
        A pygrain.FilterTransform subclass that keeps every element.
    """
    class FilterTransform(pygrain.FilterTransform):
        def filter(self, element) -> bool:
            # No filtering implemented for video sources yet — keep everything.
            return True
    return FilterTransform
219224

220225
# ----------------------------------------------------------------------------------
221226
# Helper functions for video datasets

0 commit comments

Comments
 (0)