|
| 1 | +# %% |
| 2 | + |
| 3 | +# Adapted from meta jepa |
| 4 | + |
| 5 | +import math |
| 6 | +import numpy as np |
| 7 | +from multiprocessing import Value |
| 8 | + |
class MaskCollator(object):
    """Produce encoder/predictor patch masks for a batch, one pair per mask config.

    One `_MaskGenerator` is built for every entry of `cfgs_mask`. Calling the
    collator with a batch asks each generator for masks sized to that batch.
    """

    # (config key, `_MaskGenerator` keyword argument) pairs.
    _CFG_TO_KWARG = (
        ('spatial_scale', 'pred_mask_scale'),
        ('aspect_ratio', 'aspect_ratio'),
        ('num_blocks', 'npred'),
        ('max_keep', 'max_keep'),
    )

    def __init__(
        self,
        cfgs_mask,
        crop_size=(224, 224),
        patch_size=(16, 16),
    ):
        """
        :param cfgs_mask: iterable of mapping-like configs (must support `.get`);
            recognised keys are 'spatial_scale', 'aspect_ratio', 'num_blocks'
            and 'max_keep'
        :param crop_size: forwarded unchanged to every `_MaskGenerator`
        :param patch_size: forwarded unchanged to every `_MaskGenerator`
        """
        super(MaskCollator, self).__init__()

        self.mask_generators = []
        for m in cfgs_mask:
            # Forward only the keys that are actually set. Passing None for a
            # missing key would override the generator's defaults and fail
            # much later with an unrelated-looking TypeError.
            overrides = {}
            for cfg_key, kwarg in self._CFG_TO_KWARG:
                value = m.get(cfg_key)
                if value is not None:
                    overrides[kwarg] = value
            self.mask_generators.append(
                _MaskGenerator(
                    crop_size=crop_size,
                    patch_size=patch_size,
                    **overrides
                )
            )

    def step(self):
        """Advance the iteration counter of every mask generator by one."""
        for mask_generator in self.mask_generators:
            mask_generator.step()

    def __call__(self, batch):
        """Generate masks for `batch`; only `len(batch)` is used.

        :returns: `(collated_masks_enc, collated_masks_pred)`, two lists with
            one entry per mask generator, in the order of `cfgs_mask`
        """
        batch_size = len(batch)

        collated_masks_enc, collated_masks_pred = [], []
        for mask_generator in self.mask_generators:
            masks_enc, masks_pred = mask_generator(batch_size)
            collated_masks_enc.append(masks_enc)
            collated_masks_pred.append(masks_pred)

        return collated_masks_enc, collated_masks_pred
| 46 | + |
| 47 | + |
class _MaskGenerator(object):
    """Sample rectangular prediction blocks on a 2-D patch grid.

    The grid has `crop_size // patch_size` cells per axis. Each call yields,
    for every sample of a batch, the flat indices of the patches covered by
    the prediction blocks and the flat indices of the remaining (context)
    patches.
    """

    def __init__(
        self,
        crop_size=(224, 224),
        patch_size=(16, 16),
        pred_mask_scale=(0.2, 0.8),
        aspect_ratio=(0.3, 3.0),
        npred=1,
        max_keep=None,
    ):
        """
        :param crop_size: (height, width) in pixels, or a single int for a square crop
        :param patch_size: (height, width) of a patch, or a single int
        :param pred_mask_scale: (min, max) fraction of the grid one block may cover
        :param aspect_ratio: (min, max) height/width ratio of a block
        :param npred: number of blocks sampled per image
        :param max_keep: optional upper bound on the number of context indices returned
        """
        super(_MaskGenerator, self).__init__()
        # Accept scalars, tuples and lists (configs loaded from YAML/JSON
        # deliver lists) for both sizes.
        crop_size = self._as_pair(crop_size)
        patch_size = self._as_pair(patch_size)
        self.crop_size = crop_size
        self.height, self.width = crop_size[0] // patch_size[0], crop_size[1] // patch_size[1]

        self.patch_size = patch_size
        self.aspect_ratio = aspect_ratio
        self.pred_mask_scale = pred_mask_scale
        self.npred = npred
        self.max_keep = max_keep
        self._itr_counter = Value('i', -1)  # collator is shared across worker processes

    @staticmethod
    def _as_pair(value):
        """Return `value` as a tuple; a non-iterable scalar is duplicated into a pair."""
        try:
            return tuple(value)
        except TypeError:
            return (value, ) * 2

    def step(self):
        """Increment the shared counter (modulo 2**16) under its lock and return the new value."""
        i = self._itr_counter
        with i.get_lock():
            i.value = (i.value + 1) % 2**16
            v = i.value
        return v

    def _sample_block_size(
        self,
        rng: np.random.RandomState,
        scale,
        aspect_ratio_scale
    ):
        """Draw a block (height, width) in patches from `rng`.

        :param rng: random state; exactly two uniform samples are consumed
        :param scale: (min, max) fraction of the grid covered by the block
        :param aspect_ratio_scale: (min, max) height/width ratio
        :returns: `(h, w)`, each clipped to the grid extent
        """
        # -- Sample spatial block mask scale
        _rand = rng.random()
        min_s, max_s = scale
        spatial_mask_scale = min_s + _rand * (max_s - min_s)
        spatial_num_keep = int(self.height * self.width * spatial_mask_scale)

        # -- Sample block aspect-ratio
        _rand = rng.random()
        min_ar, max_ar = aspect_ratio_scale
        aspect_ratio = min_ar + _rand * (max_ar - min_ar)

        # -- Compute block height and width (given scale and aspect-ratio)
        h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
        w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
        h = min(h, self.height)
        w = min(w, self.width)

        return (h, w)

    def _leave_context(self, b_size):
        """Shrink a block that would cover the whole grid so context patches remain.

        A full-grid block leaves no context wherever it is placed, so the
        rejection loop in `__call__` could never terminate.

        :raises ValueError: if the grid cannot hold a context patch at all
        """
        h, w = b_size
        if self.height < 1 or self.width < 1:
            raise ValueError(
                'patch grid is empty (crop_size smaller than patch_size); '
                'cannot sample masks'
            )
        if h < self.height or w < self.width:
            return (h, w)
        if self.height > 1:
            return (h - 1, w)
        if self.width > 1:
            return (h, w - 1)
        raise ValueError(
            'a 1x1 patch grid cannot hold both a prediction block and context'
        )

    def _sample_block_mask(self, b_size, rng: np.random.RandomState):
        """Return a grid-shaped int32 mask that is 0 inside a randomly placed block and 1 elsewhere."""
        h, w = b_size
        top = rng.randint(0, self.height - h + 1)
        left = rng.randint(0, self.width - w + 1)

        mask = np.ones((self.height, self.width), dtype=np.int32)
        mask[top:top+h, left:left+w] = 0

        return mask

    def __call__(self, batch_size):
        """
        Create encoder and predictor masks when collating imgs into a batch
        # 1. advance the shared counter and seed an RNG with it
        # 2. sample one pred block size for the whole batch
        # 3. sample `npred` block locations per image from the same RNG
        # 4. return the complement (enc mask) and the pred masks

        :returns: `(collated_masks_enc, collated_masks_pred)`; for
            `batch_size >= 1` each is a 2-D array with one row of flat patch
            indices per image. Rows are truncated to the shortest row in the
            batch (and to `max_keep` for the encoder masks) so they stack.
        :raises ValueError: if the patch grid cannot hold any context patch
        """
        seed = self.step()
        rng = np.random.RandomState(seed)
        p_size = self._sample_block_size(
            rng=rng,
            scale=self.pred_mask_scale,
            aspect_ratio_scale=self.aspect_ratio,
        )
        # Guarantee the rejection loop below can succeed.
        p_size = self._leave_context(p_size)

        collated_masks_pred, collated_masks_enc = [], []
        min_keep_enc = min_keep_pred = self.height * self.width
        for _ in range(batch_size):

            empty_context = True
            while empty_context:
                # Create a mask for this sample: product of all block masks,
                # i.e. 0 wherever any block landed.
                mask_e = np.ones((self.height, self.width), dtype=np.int32)
                for _blk in range(self.npred):
                    mask_e *= self._sample_block_mask(p_size, rng)
                mask_e = mask_e.flatten()

                mask_p = np.where(mask_e == 0)[0]
                mask_e = np.where(mask_e != 0)[0]

                # Several blocks together may still cover everything; resample.
                empty_context = len(mask_e) == 0
                if not empty_context:
                    min_keep_pred = min(min_keep_pred, len(mask_p))
                    min_keep_enc = min(min_keep_enc, len(mask_e))
                    collated_masks_pred.append(mask_p)
                    collated_masks_enc.append(mask_e)

        if self.max_keep is not None:
            min_keep_enc = min(min_keep_enc, self.max_keep)

        # Truncate arrays to the minimum length to create uniform arrays
        collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
        collated_masks_pred = np.array(collated_masks_pred)

        collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
        collated_masks_enc = np.array(collated_masks_enc)

        return collated_masks_enc, collated_masks_pred
| 162 | + |
| 163 | + |
0 commit comments