 from diffsynth_engine.models.base import PreTrainedModel
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.cache import Cache, DynamicCache
 from diffsynth_engine.utils import logging
@@ -152,17 +152,15 @@ def __init__(
         self,
         dim: int = 80,
         theta: float = 10000.0,
-        device: str = "cuda:0",
-        dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        with torch.device(device):
-            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        with torch.device("cpu"):
+            self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
 
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(seq, self.inv_freq)
+    def forward(self, seqlen: int, device: str) -> torch.Tensor:
+        inv_freq = self.inv_freq.to(device=device)
+        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(seq, inv_freq)
         return freqs
 
 
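Note on the rotary-embedding refactor above: `inv_freq` is now created on CPU as a plain attribute rather than a registered buffer, and the caller passes the target device at call time. A minimal standalone sketch of the same pattern (the class name and the `dim=40` toy size are illustrative, not from the repo):

```python
import torch
import torch.nn as nn


class VisionRotaryEmbeddingSketch(nn.Module):
    """Reproduces the pattern in this diff: inv_freq lives on CPU as a plain
    attribute (not a registered buffer); the device is supplied at forward time."""

    def __init__(self, dim: int = 80, theta: float = 10000.0):
        super().__init__()
        with torch.device("cpu"):
            self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    def forward(self, seqlen: int, device: str) -> torch.Tensor:
        # move the frequencies only when they are actually needed
        inv_freq = self.inv_freq.to(device=device)
        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(seq, inv_freq)


rope = VisionRotaryEmbeddingSketch(dim=40)
freqs = rope(16, device="cuda:0" if torch.cuda.is_available() else "cpu")
print(freqs.shape)  # torch.Size([16, 20])
```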
@@ -222,7 +220,7 @@ def forward(
         q = rearrange(q, "s n d -> 1 s n d")
         k = rearrange(k, "s n d -> 1 s n d")
         v = rearrange(v, "s n d -> 1 s n d")
-        out = attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
+        out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
         out = rearrange(out, "1 s n d -> s (n d)")
         out = self.proj(out)
         return out
@@ -301,7 +299,7 @@ def __init__(self, config: Qwen2_5_VLVisionConfig, device: str = "cuda:0", dtype
             dtype=dtype,
         )
         head_dim = config.hidden_size // config.num_heads
-        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device, dtype=dtype)
+        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
         self.blocks = nn.ModuleList(
             [
                 Qwen2_5_VisionBlock(
@@ -348,7 +346,7 @@ def rot_pos_emb(self, grid_thw):
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
@@ -488,7 +486,6 @@ def __init__(
         hidden_size: int = 3584,
         num_attention_heads: int = 28,
         num_key_value_heads: int = 4,
-        # dropout: float = 0.0,
         mrope_section: List[int] = [16, 24, 24],
         attn_impl: Optional[str] = None,
         device: str = "cuda:0",
@@ -501,7 +498,6 @@ def __init__(
         self.head_dim = hidden_size // num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = num_attention_heads // num_key_value_heads
-        # self.dropout = dropout
         self.mrope_section = mrope_section
         self.attn_impl = attn_impl
 
@@ -521,8 +517,6 @@ def __init__(
             self.num_attention_heads * self.head_dim, self.hidden_size, bias=False, device=device, dtype=dtype
         )
 
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=self.head_dim, device=device, dtype=dtype)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -556,14 +550,18 @@ def forward(
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[1]]
 
-        # TODO: attention_mask for flash attention 2
-        out = attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_impl=self.attn_impl,
-            attn_mask=causal_mask,
-        )
+        # TODO: use is_causal when attention mask is causal
+        if self.attn_impl == "sdpa":
+            out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True)
+        else:
+            # TODO: attention_mask for flash attention 2
+            out = attention_ops.attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_impl=self.attn_impl,
+                attn_mask=causal_mask,
+            )
         out = rearrange(out, "b s n d -> b s (n d)")
         out = self.o_proj(out)
         return out, past_key_values
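The new `sdpa` branch above passes `is_causal=True` instead of materializing `causal_mask`. Assuming the `sdpa_attn` wrapper ultimately dispatches to PyTorch's `scaled_dot_product_attention` (an assumption, its layout handling is not shown in this diff), the core call it replaces looks roughly like the sketch below. `is_causal=True` only matches the masked path when the mask really is plain causal (full-length prefill, no left padding), which is what the remaining TODO hints at.

```python
import torch
import torch.nn.functional as F

# Hedged sketch: scaled_dot_product_attention expects (batch, heads, seq, head_dim)
# and applies the causal mask internally when is_causal=True, so no explicit
# causal_mask tensor needs to be built for the prefill pass.
b, h, s, d = 1, 28, 32, 128
q = torch.randn(b, h, s, d)
k = torch.randn(b, h, s, d)
v = torch.randn(b, h, s, d)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 28, 32, 128])
```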
@@ -647,29 +645,29 @@ def forward(
 
 
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int = 128, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+    def __init__(self, dim: int = 128):
         super().__init__()
-        with torch.device(device):
-            inv_freq = self.compute_rope(dim)  # default rope without dynamic frequency
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        with torch.device("cpu"):
+            self.inv_freq = self.compute_rope(dim)  # default rope without dynamic frequency
 
     def compute_rope(self, dim: int, theta: float = 1000000.0):
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
         return inv_freq
 
     @torch.no_grad()
-    def forward(self, x, position_ids):
+    def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
         # In contrast to other models, Qwen2_5_VL has different position ids for the grids
         # So we expand the inv_freq to shape (3, ...)
-        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+        inv_freq = self.inv_freq.to(device=device)
+        inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
 
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos()
         sin = emb.sin()
 
-        return cos.to(device=x.device, dtype=x.dtype), sin.to(device=x.device, dtype=x.dtype)
+        return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
 
 
 class Qwen2_5_VLModel(nn.Module):
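For reference, the reworked `Qwen2_5_VLRotaryEmbedding.forward` above now takes explicit `device`/`dtype` and maps `(3, bs, seq)` position ids to cos/sin of shape `(3, bs, seq, dim)`. A self-contained shape check that copies the logic from this diff (the `MRoPESketch` name and toy sizes are illustrative):

```python
import torch
import torch.nn as nn


class MRoPESketch(nn.Module):
    """Standalone copy of the Qwen2_5_VLRotaryEmbedding math, for shape checking."""

    def __init__(self, dim: int = 128, theta: float = 1000000.0):
        super().__init__()
        with torch.device("cpu"):
            self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    @torch.no_grad()
    def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
        inv_freq = self.inv_freq.to(device=device)
        # (3, bs, dim//2, 1) @ (3, bs, 1, seq) -> (3, bs, dim//2, seq) -> (3, bs, seq, dim//2)
        inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(device=device, dtype=dtype), emb.sin().to(device=device, dtype=dtype)


# Qwen2.5-VL position ids carry a leading dim of 3 (temporal / height / width)
position_ids = torch.arange(10).view(1, 1, -1).expand(3, 2, -1)  # (3, bs=2, seq=10)
cos, sin = MRoPESketch(dim=128)(position_ids, device="cpu", dtype=torch.bfloat16)
print(cos.shape)  # torch.Size([3, 2, 10, 128])
```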
@@ -702,7 +700,7 @@ def __init__(self, config: Qwen2_5_VLConfig, device: str = "cuda:0", dtype: torc
         )
         self.norm = Qwen2_5_RMSNorm(config.hidden_size, config.rms_norm_eps, device=device, dtype=dtype)
         head_dim = config.hidden_size // config.num_attention_heads
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim, device=device, dtype=dtype)
+        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim)
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -749,7 +747,7 @@ def forward(
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = self.rotary_emb(position_ids, device=hidden_states.device, dtype=hidden_states.dtype)
 
         # decoder layers
         for decoder_layer in self.layers:
@@ -940,8 +938,7 @@ def from_state_dict(
         with torch.device("meta"), no_init_weights():
             model = cls(vision_config=vision_config, config=config, device=device, dtype=dtype)
         model.load_state_dict(state_dict, assign=True)
-        for param in model.parameters():  # skip buffers
-            param.data = param.data.to(device=device, dtype=dtype, non_blocking=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
     def get_input_embeddings(self):
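On the `from_state_dict` change above: the old loop moved parameters one by one and deliberately skipped buffers; with the rotary `inv_freq` tensors no longer registered as buffers, a single `model.to(...)` is presumably safe, since `nn.Module.to` converts parameters and registered buffers but leaves plain tensor attributes untouched. A quick demonstration of that PyTorch behaviour (toy module, not from the repo):

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # plain attribute: not a Parameter, not a registered buffer
        self.inv_freq = torch.arange(4, dtype=torch.float32)


m = Demo()
m.to(dtype=torch.bfloat16)
print(m.linear.weight.dtype)         # torch.bfloat16 -- parameters are converted
print(m.inv_freq.dtype)              # torch.float32  -- plain attributes are left alone
print("inv_freq" in m.state_dict())  # False -- not saved/loaded either
```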
@@ -1202,27 +1199,14 @@ def forward(
         if position_ids is None:
             assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
             # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids,
-                    image_grid_thw,
-                    video_grid_thw,
-                    second_per_grid_ts,
-                    attention_mask,
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-                )
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+            position_ids, rope_deltas = self.get_rope_index(
+                input_ids,
+                image_grid_thw,
+                video_grid_thw,
+                second_per_grid_ts,
+                attention_mask,
+            )
+            self.rope_deltas = rope_deltas
 
         hidden_states, present_key_values = self.model(
             input_ids=None,