@@ -354,6 +354,7 @@ def __init__(
354354 self .dim = dim
355355 self .max_res = max_res
356356 self .temperature = temperature
357+ self .linear_bands = linear_bands
357358 self .in_pixels = in_pixels
358359 self .feat_shape = feat_shape
359360 self .ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
383384 self .pos_embed_cos = None
384385 else :
385386 # cache full sin/cos embeddings if shape provided up front
386- emb_sin , emb_cos = build_rotary_pos_embed (
387- feat_shape = feat_shape ,
388- dim = dim ,
389- max_res = max_res ,
390- linear_bands = linear_bands ,
391- in_pixels = in_pixels ,
392- ref_feat_shape = self .ref_feat_shape ,
393- grid_offset = self .grid_offset ,
394- grid_indexing = self .grid_indexing ,
395- temperature = self .temperature ,
396- )
387+ emb_sin , emb_cos = self ._get_pos_embed_values (feat_shape )
397388 self .bands = None
398389 self .register_buffer (
399390 'pos_embed_sin' ,
@@ -406,6 +397,29 @@ def __init__(
406397 persistent = False ,
407398 )
408399
def _get_pos_embed_values(self, feat_shape: List[int]):
    """Build the sin and cos rotary embeddings for ``feat_shape``.

    The settings stored on the module (``dim``, ``max_res``, ``temperature``,
    ``linear_bands``, ``in_pixels``, ``ref_feat_shape``, ``grid_offset`` and
    ``grid_indexing``) are forwarded unchanged to ``build_rotary_pos_embed``.

    Args:
        feat_shape: Feature shape to build the embeddings for.

    Returns:
        The ``(sin, cos)`` pair unpacked from ``build_rotary_pos_embed``.
    """
    sin_cos = build_rotary_pos_embed(
        feat_shape=feat_shape,
        dim=self.dim,
        max_res=self.max_res,
        temperature=self.temperature,
        linear_bands=self.linear_bands,
        in_pixels=self.in_pixels,
        ref_feat_shape=self.ref_feat_shape,
        grid_offset=self.grid_offset,
        grid_indexing=self.grid_indexing,
    )
    # Unpack so the caller always receives a 2-tuple, regardless of the
    # sequence type returned by the builder.
    sin_part, cos_part = sin_cos
    return sin_part, cos_part
413+
def update_feat_shape(self, feat_shape: List[int]):
    """Rebuild the cached sin/cos embeddings for a new feature shape.

    Nothing happens when the module was created without a fixed feature
    shape (``self.feat_shape`` is ``None``) or when ``feat_shape`` equals the
    shape that is already cached.

    Args:
        feat_shape: New feature shape to cache embeddings for.
    """
    if self.feat_shape is not None and feat_shape != self.feat_shape:
        # only update if feat_shape was set and different from previous value
        # The cache read and written below lives in `pos_embed_sin` and
        # `pos_embed_cos`, so those are the attributes that must be populated.
        assert self.pos_embed_sin is not None
        assert self.pos_embed_cos is not None
        emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
        # Keep the device/dtype of the existing cache so a module that was
        # already moved or cast stays consistent after the update.
        self.pos_embed_sin = emb_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype)
        self.pos_embed_cos = emb_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype)
        self.feat_shape = feat_shape
422+
409423 def get_embed (self , shape : Optional [List [int ]] = None ):
410424 if shape is not None and self .bands is not None :
411425 # rebuild embeddings every call, use if target shape changes
@@ -453,6 +467,7 @@ def __init__(
453467 self .max_res = max_res
454468 self .temperature = temperature
455469 self .in_pixels = in_pixels
470+ self .linear_bands = linear_bands
456471 self .feat_shape = feat_shape
457472 self .ref_feat_shape = ref_feat_shape
458473 self .grid_offset = grid_offset
@@ -480,27 +495,40 @@ def __init__(
480495 self .pos_embed = None
481496 else :
482497 # cache full sin/cos embeddings if shape provided up front
483- embeds = build_rotary_pos_embed (
484- feat_shape = feat_shape ,
485- dim = dim ,
486- max_res = max_res ,
487- linear_bands = linear_bands ,
488- in_pixels = in_pixels ,
489- ref_feat_shape = self .ref_feat_shape ,
490- grid_offset = self .grid_offset ,
491- grid_indexing = self .grid_indexing ,
492- temperature = self .temperature ,
493- )
494498 self .bands = None
495499 self .register_buffer (
496500 'pos_embed' ,
497- torch . cat ( embeds , - 1 ),
501+ self . _get_pos_embed_values ( feat_shape = feat_shape ),
498502 persistent = False ,
499503 )
500504
def _get_pos_embed_values(self, feat_shape: List[int]):
    """Build the concatenated rotary embedding for ``feat_shape``.

    The settings stored on the module (``dim``, ``max_res``, ``temperature``,
    ``linear_bands``, ``in_pixels``, ``ref_feat_shape``, ``grid_offset`` and
    ``grid_indexing``) are forwarded unchanged to ``build_rotary_pos_embed``.

    Args:
        feat_shape: Feature shape to build the embeddings for.

    Returns:
        The outputs of ``build_rotary_pos_embed`` joined along the last
        dimension with ``torch.cat``.
    """
    parts = build_rotary_pos_embed(
        feat_shape=feat_shape,
        dim=self.dim,
        max_res=self.max_res,
        temperature=self.temperature,
        linear_bands=self.linear_bands,
        in_pixels=self.in_pixels,
        ref_feat_shape=self.ref_feat_shape,
        grid_offset=self.grid_offset,
        grid_indexing=self.grid_indexing,
    )
    joined = torch.cat(parts, -1)
    return joined
518+
def update_feat_shape(self, feat_shape: List[int]):
    """Rebuild the cached concatenated embedding for a new feature shape.

    Nothing happens when the module was created without a fixed feature
    shape (``self.feat_shape`` is ``None``) or when ``feat_shape`` equals the
    shape that is already cached.

    Args:
        feat_shape: New feature shape to cache the embedding for.
    """
    # Guard: only update if feat_shape was set and differs from the cached one.
    if self.feat_shape is None or feat_shape == self.feat_shape:
        return

    assert self.pos_embed is not None
    rebuilt = self._get_pos_embed_values(feat_shape)
    # Match the device/dtype of the existing cache before swapping it in.
    previous = self.pos_embed
    self.pos_embed = rebuilt.to(device=previous.device, dtype=previous.dtype)
    self.feat_shape = feat_shape
528+
501529 def get_embed (self , shape : Optional [List [int ]] = None ):
502530 if shape is not None and self .bands is not None :
503- # rebuild embeddings every call, use if target shape changes
531+ # rebuild embeddings from cached bands every call, use if target shape changes
504532 embeds = build_rotary_pos_embed (
505533 shape ,
506534 self .bands ,
@@ -684,6 +712,7 @@ def __init__(
684712
685713 head_dim = dim // num_heads
686714 assert head_dim % 4 == 0 , f"head_dim must be divisible by 4, got { head_dim } "
715+
687716 freqs = init_random_2d_freqs (
688717 head_dim ,
689718 depth ,
@@ -692,18 +721,32 @@ def __init__(
692721 rotate = True ,
693722 ) # (2, depth, num_heads, head_dim//2)
694723 self .freqs = nn .Parameter (freqs )
724+
695725 if feat_shape is not None :
696726 # cache pre-computed grid
697- t_x , t_y = get_mixed_grid (
698- feat_shape ,
699- grid_indexing = grid_indexing ,
700- device = self .freqs .device
701- )
727+ t_x , t_y = self ._get_grid_values (feat_shape )
702728 self .register_buffer ('t_x' , t_x , persistent = False )
703729 self .register_buffer ('t_y' , t_y , persistent = False )
704730 else :
705731 self .t_x = self .t_y = None
706732
def _get_grid_values(self, feat_shape: Optional[List[int]]):
    """Compute the x/y grid tensors for ``feat_shape``.

    Delegates to ``get_mixed_grid`` using the module's ``grid_indexing``
    setting and the device of the ``freqs`` parameter.

    Args:
        feat_shape: Feature shape handed to ``get_mixed_grid``.

    Returns:
        The ``(t_x, t_y)`` pair unpacked from ``get_mixed_grid``.
    """
    grid = get_mixed_grid(
        feat_shape,
        grid_indexing=self.grid_indexing,
        device=self.freqs.device,
    )
    grid_x, grid_y = grid
    return grid_x, grid_y
740+
def update_feat_shape(self, feat_shape: Optional[List[int]]):
    """Rebuild the cached ``t_x`` / ``t_y`` grids for a new feature shape.

    Nothing happens when the module was created without a fixed feature
    shape (``self.feat_shape`` is ``None``) or when ``feat_shape`` equals the
    shape that is already cached.

    Args:
        feat_shape: New feature shape to cache grids for.
            NOTE(review): the annotation allows ``None``, but a ``None`` value
            is still forwarded to ``_get_grid_values`` when a shape is already
            cached; confirm whether callers ever pass it.
    """
    # Guard: only update if feat_shape was set and differs from the cached one.
    if self.feat_shape is None or feat_shape == self.feat_shape:
        return

    assert self.t_x is not None
    assert self.t_y is not None
    grid_x, grid_y = self._get_grid_values(feat_shape)
    # Match the device/dtype of the existing cached grids before replacing them.
    self.t_x = grid_x.to(self.t_x.device, self.t_x.dtype)
    self.t_y = grid_y.to(self.t_y.device, self.t_y.dtype)
    self.feat_shape = feat_shape
749+
707750 def get_embed (self , shape : Optional [List [int ]] = None ) -> torch .Tensor :
708751 """Generate rotary embeddings for the given spatial shape.
709752
0 commit comments