
Commit a5ac374

feat: support linear scaled rope for tgis_native llama (#61)
Implements a new LinearScalingPositionRotaryEmbedding layer that supports linear scaling of position ids when processing embeddings. Without this, models with a linear rope_scaling configuration could load fine but would give garbage output.

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: TRAVIS JOHNSON <[email protected]>
1 parent 03db106 commit a5ac374
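
To make the behaviour concrete, here is a minimal, standalone sketch of the configuration dispatch this commit adds. The rope_scaling dict and its factor of 4.0 are illustrative values a long-context fine-tune might ship in its config.json, not values taken from this commit.

# Minimal sketch of the rope_scaling dispatch added in flash_llama_modeling.py.
# The dict below is a hypothetical example of what a long-context fine-tune
# might carry in its config.json; it is not taken from this commit.
rope_scaling = {"type": "linear", "factor": 4.0}

if rope_scaling and "type" in rope_scaling:
    if rope_scaling["type"] == "linear":
        # This commit routes to LinearScalingPositionRotaryEmbedding here.
        chosen = ("LinearScalingPositionRotaryEmbedding", rope_scaling.get("factor", 1.0))
    else:
        # Any other scaling type is rejected rather than silently mis-rotating positions.
        raise ValueError(f"rope_scaling of type {rope_scaling['type']} is not supported")
else:
    # No rope_scaling entry: behaviour is unchanged.
    chosen = ("PositionRotaryEmbedding", 1.0)

print(chosen)  # ('LinearScalingPositionRotaryEmbedding', 4.0)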

2 files changed (+42, -10 lines)

server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

Lines changed: 17 additions & 6 deletions
@@ -35,6 +35,7 @@
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     PositionRotaryEmbedding,
+    LinearScalingPositionRotaryEmbedding,
     TensorParallelHead,
     get_linear,
 )
@@ -183,12 +184,22 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads

-        # self.rotary_emb = PositionRotaryEmbedding.load(
-        #     prefix=f"{prefix}.rotary_emb", weights=weights
-        # )
-        self.rotary_emb = PositionRotaryEmbedding.static(
-            dim=self.head_size, base=config.rope_theta, device=weights.device
-        )
+        if config.rope_scaling and "type" in config.rope_scaling:
+            if config.rope_scaling["type"] == "linear":
+                self.rotary_emb = LinearScalingPositionRotaryEmbedding.static(
+                    dim=self.head_size,
+                    base=config.rope_theta,
+                    scaling_factor=config.rope_scaling.get("factor", 1.0),
+                    device=weights.device
+                )
+            else:
+                raise ValueError(
+                    f"rope_scaling of type {config.rope_scaling['type']} is not supported with FLASH_ATTENTION=True"
+                )
+        else:
+            self.rotary_emb = PositionRotaryEmbedding.static(
+                dim=self.head_size, base=config.rope_theta, device=weights.device
+            )

         self.softmax_scale = self.head_size**-0.5
server/text_generation_server/utils/layers.py

Lines changed: 25 additions & 4 deletions
@@ -389,24 +389,25 @@ def forward(self, hidden_states, residual=None):
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb

-    class PositionRotaryEmbedding(nn.Module):
-        def __init__(self, inv_freq):
+    class BasePositionRotaryEmbedding(nn.Module):
+        def __init__(self, inv_freq, scaling_factor=1.0):
             super().__init__()

             self.inv_freq = inv_freq
+            self.scaling_factor = scaling_factor
             self._seq_len_cached = 0
             self._cos_cached = None
             self._sin_cached = None
             self._cos_k_cached = None
             self._sin_k_cached = None

         @classmethod
-        def static(cls, dim, base, device):
+        def static(cls, dim, base, device, scaling_factor=1.0):
             inv_freq = 1.0 / (
                 base
                 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
             )
-            return cls(inv_freq)
+            return cls(inv_freq, scaling_factor)

         @classmethod
         def load(cls, prefix, weights):
@@ -427,6 +428,8 @@ def _update_cos_sin_cache(self, dtype, device, seqlen):
             ):
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                if self.scaling_factor != 1.0:
+                    t = t / self.scaling_factor
                 # Don't do einsum, it converts fp32 to fp16
                 # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
                 freqs = torch.outer(t, self.inv_freq.to(device=t.device))
@@ -454,5 +457,23 @@ def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
             rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
             return x

+    class PositionRotaryEmbedding(BasePositionRotaryEmbedding):
+        @classmethod
+        def static(cls, dim, base, device):
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
+            return cls(inv_freq)
+
+    class LinearScalingPositionRotaryEmbedding(BasePositionRotaryEmbedding):
+        @classmethod
+        def static(cls, dim, base, scaling_factor, device):
+            inv_freq = 1.0 / (
+                base
+                ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            )
+            return cls(inv_freq, scaling_factor)
+
 except ImportError:
     pass
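
As a quick sanity check on the scaling_factor change in _update_cos_sin_cache, the following self-contained sketch (plain PyTorch, not the repository's classes) computes the rotary cos table with and without a factor of 4 and shows that scaled position 400 receives the same angles as unscaled position 100, which is exactly the compression of position ids that linear RoPE scaling relies on. The base of 10000.0 is only an assumed default for the sketch.

# Self-contained sketch (requires torch only; not the repository code) showing the
# effect of dividing positions by scaling_factor before the outer product, as in
# the _update_cos_sin_cache change above.
import torch

def rotary_cos_sin(seqlen, dim, base=10000.0, scaling_factor=1.0):
    # Inverse frequencies per even head dimension, same formula as static() above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(seqlen, dtype=torch.float32)
    if scaling_factor != 1.0:
        t = t / scaling_factor  # linear position scaling
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs), torch.sin(freqs)

cos_plain, _ = rotary_cos_sin(seqlen=2048, dim=128)
cos_scaled, _ = rotary_cos_sin(seqlen=2048, dim=128, scaling_factor=4.0)
# With factor=4, position 400 gets the same rotary angles as unscaled position 100.
assert torch.allclose(cos_scaled[400], cos_plain[100])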
