@@ -3,7 +3,6 @@
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
-import logging

 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -177,7 +176,7 @@ class Gemma3_4B_Config:
     num_key_value_heads: int = 4
     max_position_embeddings: int = 131072
     rms_norm_eps: float = 1e-6
-    rope_theta = [10000.0, 1000000.0]
+    rope_theta = [1000000.0, 10000.0]
     transformer_type: str = "gemma3"
     head_dim = 256
     rms_norm_add = True
@@ -186,8 +185,8 @@ class Gemma3_4B_Config:
     rope_dims = None
     q_norm = "gemma3"
     k_norm = "gemma3"
-    sliding_attention = [False, False, False, False, False, 1024]
-    rope_scale = [1.0, 8.0]
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
     final_norm: bool = True

 class RMSNorm(nn.Module):
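For reference, a minimal sketch of the two rotary frequency tables these config entries describe, assuming standard RoPE with linear position scaling and assuming the two freqs_cis entries used later in forward() are built from index 0 and index 1 of rope_theta/rope_scale respectively (the exact precompute helper in llama.py may differ). Under that assumption, full-attention layers (freqs_cis[0]) get theta 1000000.0 with 8x scaling and sliding-window layers (freqs_cis[1]) get theta 10000.0 unscaled.

import torch

head_dim = 256
rope_theta = [1000000.0, 10000.0]   # [full attention, sliding window]
rope_scale = [8.0, 1.0]

def inv_freq(theta, scale, dim=head_dim):
    # Standard RoPE inverse frequencies; linear scaling divides them by the factor.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return freqs / scale

global_freqs = inv_freq(rope_theta[0], rope_scale[0])   # layers with full attention
sliding_freqs = inv_freq(rope_theta[1], rope_scale[1])  # layers with a 1024-token window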
@@ -370,7 +369,7 @@ def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: An
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)

-        if config.sliding_attention is not None:  # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+        if config.sliding_attention is not None:
             self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
         else:
             self.sliding_attention = False
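A small illustration of how the modulo indexing above expands the 6-entry sliding_attention pattern across layers (the 12-layer count here is hypothetical, chosen only for the printout): five consecutive sliding-window layers followed by one full-attention layer, repeating.

sliding_attention = [1024, 1024, 1024, 1024, 1024, False]

for index in range(12):  # layer count chosen only for illustration
    window = sliding_attention[index % len(sliding_attention)]
    kind = f"sliding window of {window}" if window else "full attention"
    print(f"layer {index:2d}: {kind}")
# layers 0-4 and 6-10 -> sliding window of 1024; layers 5 and 11 -> full attention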
@@ -387,7 +386,12 @@ def forward(
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
                 if x.shape[1] > self.sliding_attention:
-                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+                    sliding_mask.tril_(diagonal=-self.sliding_attention)
+                    if attention_mask is not None:
+                        attention_mask = attention_mask + sliding_mask
+                    else:
+                        attention_mask = sliding_mask
                 freqs_cis = freqs_cis[1]
             else:
                 freqs_cis = freqs_cis[0]
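A minimal standalone sketch of the mask construction above on a toy sequence, showing what tril_ with a negative diagonal produces: combined with the usual causal mask, row i keeps only keys j with i - window < j <= i, and everything older is -inf.

import torch

seq_len, window = 6, 2
causal = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)          # future tokens masked
sliding = torch.full((seq_len, seq_len), float("-inf")).tril_(diagonal=-window)   # tokens older than the window masked
print(causal + sliding)  # 0.0 only where i - window < j <= i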