Commit 0aa7fa4

Implement sliding attention in Gemma3 (#11409)
1 parent 514c24d commit 0aa7fa4

File tree

1 file changed: +10 -6 lines changed

comfy/text_encoders/llama.py

Lines changed: 10 additions & 6 deletions
@@ -3,7 +3,6 @@
 from dataclasses import dataclass
 from typing import Optional, Any
 import math
-import logging
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
@@ -177,7 +176,7 @@ class Gemma3_4B_Config:
     num_key_value_heads: int = 4
     max_position_embeddings: int = 131072
     rms_norm_eps: float = 1e-6
-    rope_theta = [10000.0, 1000000.0]
+    rope_theta = [1000000.0, 10000.0]
     transformer_type: str = "gemma3"
     head_dim = 256
     rms_norm_add = True
@@ -186,8 +185,8 @@ class Gemma3_4B_Config:
     rope_dims = None
     q_norm = "gemma3"
     k_norm = "gemma3"
-    sliding_attention = [False, False, False, False, False, 1024]
-    rope_scale = [1.0, 8.0]
+    sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+    rope_scale = [8.0, 1.0]
     final_norm: bool = True
 
 class RMSNorm(nn.Module):
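
Note on the reordered config lists (annotation, not part of the diff): forward() below picks freqs_cis[1] for sliding-window layers and freqs_cis[0] for global layers, so index 0 of rope_theta/rope_scale is the global-layer setting and index 1 the sliding-layer setting. With the swap, global layers get theta 1000000.0 with scale 8.0 and sliding layers get theta 10000.0 with scale 1.0; sliding_attention likewise now marks five of every six layers as sliding-window (1024 tokens) instead of one in six. A minimal sketch of that selection, assuming freqs_cis[i] is built from rope_theta[i] and rope_scale[i]:

rope_theta = [1000000.0, 10000.0]  # index 0: global layers, index 1: sliding layers (assumed pairing)
rope_scale = [8.0, 1.0]

def rope_params(is_sliding_layer: bool):
    # Mirrors the freqs_cis[1] vs freqs_cis[0] choice in forward() below.
    i = 1 if is_sliding_layer else 0
    return rope_theta[i], rope_scale[i]

assert rope_params(is_sliding_layer=False) == (1000000.0, 8.0)
assert rope_params(is_sliding_layer=True) == (10000.0, 1.0)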
@@ -370,7 +369,7 @@ def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: An
         self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
         self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
 
-        if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
+        if config.sliding_attention is not None:
             self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
         else:
             self.sliding_attention = False
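
The __init__ change above selects each layer's window by cycling through the config list. A small sketch (my own example; the 12-layer count is hypothetical) of how the pattern expands across a stack:

sliding_attention = [1024, 1024, 1024, 1024, 1024, False]

for index in range(12):
    window = sliding_attention[index % len(sliding_attention)]
    kind = "global" if window is False else f"sliding window of {window} tokens"
    print(f"layer {index:2d}: {kind}")
# Layers 0-4 and 6-10 use the 1024-token sliding window; layers 5 and 11 attend globally.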
@@ -387,7 +386,12 @@ def forward(
         if self.transformer_type == 'gemma3':
             if self.sliding_attention:
                 if x.shape[1] > self.sliding_attention:
-                    logging.warning("Warning: sliding attention not implemented, results may be incorrect")
+                    sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+                    sliding_mask.tril_(diagonal=-self.sliding_attention)
+                    if attention_mask is not None:
+                        attention_mask = attention_mask + sliding_mask
+                    else:
+                        attention_mask = sliding_mask
                 freqs_cis = freqs_cis[1]
             else:
                 freqs_cis = freqs_cis[0]
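
A standalone sketch (my own, not from the commit) of what the added mask does: torch.full builds an all -inf matrix and tril_(diagonal=-window) zeroes every entry except those where the key is at least `window` positions behind the query, so adding it to the existing attention_mask (assumed here to be a causal mask) leaves each query with only its most recent `window` keys visible.

import torch

seq_len, window = 6, 2  # toy sizes; the real window is self.sliding_attention (e.g. 1024)

# Stand-in for the attention_mask the block already receives: causal, -inf above the diagonal.
causal = torch.full((seq_len, seq_len), float("-inf")).triu_(diagonal=1)

# Sliding mask exactly as built in the diff: -inf kept only where key index <= query index - window.
sliding = torch.full((seq_len, seq_len), float("-inf"))
sliding.tril_(diagonal=-window)

combined = causal + sliding
print(combined)
# Row i has finite entries only for keys i-window+1 .. i, i.e. the current token and the window-1 before it.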
