
Commit 1ade8b8

Added Gemma (#267)
1 parent b407adc commit 1ade8b8

File tree: 10 files changed (+852, −91 lines)

docs/models/adapters.md

Lines changed: 16 additions & 6 deletions

```diff
@@ -36,13 +36,15 @@ Any combination of linear layers can be targeted in the adapters, which correspo
 - `o_proj`
 - `lm_head`
 
-### Qwen
+### Gemma
 
-- `c_attn`
-- `c_proj`
-- `w1`
-- `w2`
-- `lm_head`
+- `q_proj`
+- `k_proj`
+- `v_proj`
+- `o_proj`
+- `gate_proj`
+- `up_proj`
+- `down_proj`
 
 ### Phi
 
@@ -54,6 +56,14 @@ Any combination of linear layers can be targeted in the adapters, which correspo
 - `fc2`
 - `lm_head`
 
+### Qwen
+
+- `c_attn`
+- `c_proj`
+- `w1`
+- `w2`
+- `lm_head`
+
 ### GPT2
 
 - `c_attn`
```
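The Gemma entry above lists the linear layers a LoRA adapter can target (`q_proj` through `down_proj`). As an illustration only (this helper is hypothetical and not part of LoRAX), a small function could validate a requested set of target modules against that documented list:

```python
# Hypothetical helper (not LoRAX code): validate LoRA target modules
# against the layers the docs list for Gemma.

GEMMA_TARGET_MODULES = {
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
}

def validate_target_modules(requested, supported=GEMMA_TARGET_MODULES):
    """Return the requested modules, raising on any unsupported name."""
    unsupported = set(requested) - supported
    if unsupported:
        raise ValueError(f"Unsupported target modules: {sorted(unsupported)}")
    return list(requested)

# Targeting only the attention projections is a common LoRA choice.
print(validate_target_modules(["q_proj", "k_proj", "v_proj", "o_proj"]))
```

Note that unlike Qwen's fused `c_attn`, Gemma exposes separate query/key/value projections, so attention adapters name three layers instead of one.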

docs/models/base_models.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -7,8 +7,9 @@
 - 🌬️[Mistral](https://huggingface.co/mistralai)
 - [Zephyr](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta)
 - 🔄 [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
-- 🔮 [Qwen](https://huggingface.co/Qwen)
+- 💎 [Gemma](https://blog.google/technology/developers/gemma-open-models/)
 - 🏛️ [Phi](https://huggingface.co/microsoft/phi-2)
+- 🔮 [Qwen](https://huggingface.co/Qwen)
 - 🤖 [GPT2](https://huggingface.co/gpt2)
 - 🌸 [Bloom](https://huggingface.co/bigscience/bloom)
 
```

server/lorax_server/models/__init__.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -50,6 +50,7 @@
 from lorax_server.models.flash_rw import FlashRWSharded
 from lorax_server.models.flash_neox import FlashNeoXSharded
 from lorax_server.models.flash_llama import FlashLlama
+from lorax_server.models.flash_gemma import FlashGemma
 from lorax_server.models.flash_gpt2 import FlashGPT2
 from lorax_server.models.flash_qwen import FlashQwen
 from lorax_server.models.flash_phi import FlashPhi
@@ -66,6 +67,7 @@
 __all__.append(FlashRWSharded)
 __all__.append(FlashSantacoderSharded)
 __all__.append(FlashLlama)
+__all__.append(FlashGemma)
 __all__.append(FlashGPT2)
 __all__.append(FlashQwen)
 __all__.append(FlashPhi)
@@ -361,6 +363,20 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         raise NotImplementedError("Phi model requires flash attention v2")
+
+    if model_type == "gemma":
+        if FLASH_ATTENTION:
+            return FlashGemma(
+                model_id,
+                adapter_id,
+                adapter_source,
+                revision,
+                quantize=quantize,
+                compile=compile,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        raise NotImplementedError("Gemma model requires flash attention v2")
 
     if model_type == "opt":
         return OPTSharded(
```
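The `get_model` hunk above follows one dispatch pattern per architecture: branch on `model_type`, return the flash-attention implementation when the kernels are available, otherwise raise. A minimal self-contained sketch of that pattern (the `FlashGemma` class and `FLASH_ATTENTION` flag here are simplified stand-ins, not the real `lorax_server` objects, and the model id is only an example):

```python
# Simplified sketch of the dispatch pattern in get_model; stand-in
# names, not the actual lorax_server imports.

FLASH_ATTENTION = True  # assume flash attention v2 kernels are available


class FlashGemma:
    """Stand-in for lorax_server.models.flash_gemma.FlashGemma."""

    def __init__(self, model_id, **kwargs):
        self.model_id = model_id
        self.config = kwargs


def get_model(model_type, model_id, **kwargs):
    # One guarded branch per architecture: return the flash-attention
    # implementation when available, otherwise fail loudly.
    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(model_id, **kwargs)
        raise NotImplementedError("Gemma model requires flash attention v2")
    raise ValueError(f"Unsupported model type: {model_type}")


model = get_model("gemma", "google/gemma-2b", quantize=None, dtype="bfloat16")
print(type(model).__name__)  # FlashGemma
```

Raising `NotImplementedError` instead of silently falling back makes the flash-attention requirement explicit at load time rather than surfacing as a slower or failing runtime path.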

0 comments