@@ -515,16 +515,19 @@ def patchify_and_embed(
515515 start = (cap_ori_len + cap_padding_len + 1 , 0 , 0 ),
516516 device = device ,
517517 ).flatten (0 , 2 )
518- image_padding_pos_ids = (
519- self .create_coordinate_grid (
520- size = (1 , 1 , 1 ),
521- start = (0 , 0 , 0 ),
522- device = device ,
518+ if image_padding_len > 0 :
519+ image_padding_pos_ids = (
520+ self .create_coordinate_grid (
521+ size = (1 , 1 , 1 ),
522+ start = (0 , 0 , 0 ),
523+ device = device ,
524+ )
525+ .flatten (0 , 2 )
526+ .repeat (image_padding_len , 1 )
523527 )
524- .flatten (0 , 2 )
525- .repeat (image_padding_len , 1 )
526- )
527- image_padded_pos_ids = torch .cat ([image_ori_pos_ids , image_padding_pos_ids ], dim = 0 )
528+ image_padded_pos_ids = torch .cat ([image_ori_pos_ids , image_padding_pos_ids ], dim = 0 )
529+ else :
530+ image_padded_pos_ids = image_ori_pos_ids
528531 all_image_pos_ids .append (image_padded_pos_ids )
529532 # pad mask
530533 all_image_pad_mask .append (
@@ -534,10 +537,10 @@ def patchify_and_embed(
534537 torch .ones ((image_padding_len ,), dtype = torch .bool , device = device ),
535538 ],
536539 dim = 0 ,
537- )
540+ ) if image_padding_len > 0 else torch . zeros (( image_ori_len ,), dtype = torch . bool , device = device )
538541 )
539542 # padded feature
540- image_padded_feat = torch .cat ([image , image [- 1 :].repeat (image_padding_len , 1 )], dim = 0 )
543+ image_padded_feat = torch .cat ([image , image [- 1 :].repeat (image_padding_len , 1 )], dim = 0 ) if image_padding_len > 0 else image
541544 all_image_out .append (image_padded_feat )
542545
543546 return (
0 commit comments