@@ -67,6 +67,15 @@ def position_ids_in_meshgrid(patch_embeds_list, max_width, flatten=True):
     return torch.stack(positions)
 
 
+
+# from bagel
+def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
+    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
+    coords_h = torch.arange(0, num_patches_h)
+    coords_w = torch.arange(0, num_patches_w)
+    pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
+    return pos_ids
+
 def create_block_diagonal_mask(lengths, device):
     """
     Create a block diagonal mask based on sequence lengths.
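Note on the helper added in the hunk above: a minimal usage sketch, illustrative only and not part of the diff. It assumes get_flattened_position_ids_extrapolate is in scope and uses made-up sizes; a 56x84 image with 14px patches gives a 4x6 patch grid, and patch (i, j) is assigned id i * max_num_patches_per_side + j.

import torch

# hypothetical sizes: 56x84 pixels, 14px patches -> 4x6 patch grid
pos = get_flattened_position_ids_extrapolate(
    img_h=56, img_w=84, patch_size=14, max_num_patches_per_side=70
)
# pos -> tensor([  0,   1,   2,   3,   4,   5,
#                 70,  71,  72,  73,  74,  75,
#                140, 141, 142, 143, 144, 145,
#                210, 211, 212, 213, 214, 215])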
@@ -88,6 +97,18 @@ def create_block_diagonal_mask(lengths, device):
     return mask.to(device)
 
 
+# grabbed from bagel repo
+
+def patchify(image, patch_size):
+    p = patch_size
+    c, h, w = image.shape
+    assert h % p == 0 and w % p == 0
+    image = image.reshape(c, h // p, p, w // p, p)
+    image = torch.einsum("chpwq->hwpqc", image)
+    image = image.reshape(-1, p**2 * c)
+    return image
+
+
 class VisionEncoder(nn.Module):
     def __init__(self, encoder_config, running_config=None):
         super(VisionEncoder, self).__init__()
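A shape sketch for patchify from the hunk above (illustrative, assuming the helper is in scope; the image size is hypothetical): each patch is flattened row-major with channels last, yielding one row of length patch_size**2 * channels per patch.

import torch

img = torch.randn(3, 56, 84)            # (C, H, W), made-up size
patches = patchify(img, patch_size=14)
print(patches.shape)                     # torch.Size([24, 588]) -> (4*6 patches, 14*14*3)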
@@ -99,12 +120,18 @@ def __init__(self, encoder_config, running_config=None):
             )
         else:
             self.rope = build_rope(encoder_config, mode="2d")
-        self.patch_conv = nn.Conv2d(
-            in_channels=encoder_config.num_channels,
-            out_channels=encoder_config.hidden_size,
-            kernel_size=encoder_config.patch_size,
-            stride=encoder_config.patch_size,
-            bias=encoder_config.patch_conv_bias,
+        # self.patch_conv = nn.Conv2d(
+        #     in_channels=encoder_config.num_channels,
+        #     out_channels=encoder_config.hidden_size,
+        #     kernel_size=encoder_config.patch_size,
+        #     stride=encoder_config.patch_size,
+        #     bias=encoder_config.patch_conv_bias,
+        # )
+        # linear patch conv for bagel
+        self.patch_conv = nn.Linear(
+            encoder_config.patch_size * encoder_config.patch_size * encoder_config.num_channels,
+            encoder_config.hidden_size,
+            bias=True,
         )
         if encoder_config.layernorm_pre:
             self.ln_pre = RMSNorm(encoder_config.hidden_size, eps=1e-5)
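The nn.Conv2d to nn.Linear swap above matches the old stride-patch_size convolution numerically once the conv weight is permuted to channels-last before flattening, because patchify lays out each patch as (row, col, channel). A hedged sanity-check sketch, not part of the diff, assuming patchify is in scope and using made-up sizes:

import torch
import torch.nn as nn

c, p, hidden = 3, 14, 32                        # hypothetical config values
conv = nn.Conv2d(c, hidden, kernel_size=p, stride=p, bias=True)
linear = nn.Linear(p * p * c, hidden, bias=True)
with torch.no_grad():
    # reorder conv weight (hidden, c, p, p) -> (hidden, p, p, c) to match patchify's layout
    linear.weight.copy_(conv.weight.permute(0, 2, 3, 1).reshape(hidden, -1))
    linear.bias.copy_(conv.bias)

img = torch.randn(c, 2 * p, 3 * p)                                        # 2x3 patch grid
via_conv = conv(img.unsqueeze(0)).flatten(2).transpose(1, 2).squeeze(0)   # (6, hidden)
via_linear = linear(patchify(img, p))                                     # (6, hidden)
assert torch.allclose(via_conv, via_linear, atol=1e-5)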
@@ -133,7 +160,8 @@ def from_config(cls, encoder_config, running_config=None):
 
     @property
     def max_patches_per_side(self):
-        return self.encoder_config.image_size // self.encoder_config.patch_size
+        return 70  # hardcoded bagel value
+        # return self.encoder_config.image_size // self.encoder_config.patch_size
 
     @property
     def device(self):
@@ -151,8 +179,10 @@ def forward(self, images):
         # TODO add as @property somewhere
         dtype = next(self.parameters()).dtype
 
+        pixel_values = [patchify(img, self.encoder_config.patch_size) for img in images]
+
         # pass images through initial convolution independently (because they may have different sizes)
-        patch_embeds_list = [self.patch_conv(img.to(dtype)) for img in images]
+        patch_embeds_list = [self.patch_conv(pv.to(dtype)) for pv in pixel_values]
 
         if self.ln_pre is not None:  # pixtral / mistral
             # flatten H+W then change to (H+W, C) and stack all images of ex
@@ -171,17 +201,32 @@ def forward(self, images):
             patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
             mask = None
 
+        patch_embeds = patch_embeds.transpose(1, 2)  # (N_img, Seqlen, D)
+
         # positional embeddings
-        positions = position_ids_in_meshgrid(
-            patch_embeds_list,
-            max_width=self.encoder_config.image_size // self.encoder_config.patch_size,
-            flatten=self.ln_pre is not None,  # dirty flag need to improve
-        ).to(self.device)
+        # positions = position_ids_in_meshgrid(
+        #     # patch_embeds_list,
+        #     images,
+        #     max_width=self.encoder_config.image_size // self.encoder_config.patch_size,
+        #     flatten=self.ln_pre is not None,  # dirty flag need to improve
+        # ).to(self.device)
+        positions = torch.cat([
+            get_flattened_position_ids_extrapolate(
+                img.shape[-2],
+                img.shape[-1],
+                self.encoder_config.patch_size,
+                self.max_patches_per_side,
+
+            )
+            for img in images
+        ], axis=0).unsqueeze(0).to(self.device)
+
         # TODO: make this cleaner
         if hasattr(self, "position_embeddings"):
             # this is only used for rope
             position_embeddings = None
-            patch_embeds += self.position_embeddings(positions)
+            pos_embeds = self.position_embeddings(positions)
+            patch_embeds += pos_embeds
         else:
             position_embeddings = self.rope.update(
                 patch_embeds.size(1),
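To illustrate the new position handling in this hunk (a sketch only, with hypothetical image sizes and assuming the helpers above are in scope): per-image ids are concatenated along the sequence dimension and index a single position embedding table of max_patches_per_side ** 2 entries, i.e. 4900 rows with the hardcoded value of 70.

import torch

images = [torch.randn(3, 28, 42), torch.randn(3, 14, 14)]   # 2x3 and 1x1 patch grids
positions = torch.cat([
    get_flattened_position_ids_extrapolate(img.shape[-2], img.shape[-1], 14, 70)
    for img in images
], dim=0).unsqueeze(0)
# positions -> tensor([[ 0,  1,  2, 70, 71, 72,  0]]), shape (1, 7)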
@@ -197,7 +242,7 @@ def forward(self, images):
         if self.post_layernorm is not None:
             out = self.post_layernorm(out)
 
-        return out
+        return out, positions
 
 
 # Multi-Modal Projector
@@ -266,4 +311,5 @@ def from_config(cls, model_config, running_config=None):
 str2adapter = {
     "llava": VisionLanguageAdapter,
     "gemma3": Gemma3MultiModalProjector,
+    "bagel": VisionLanguageAdapter,
 }