
Commit db7e748
Add Gemma 3 support
Parent: 4644982

File tree: 2 files changed, +169 additions, −0 deletions

fastchat/model/model_adapter.py
Lines changed: 37 additions & 0 deletions

@@ -23,6 +23,7 @@
     LlamaTokenizer,
     LlamaForCausalLM,
     T5Tokenizer,
+    Gemma3ForCausalLM,
 )

 from fastchat.constants import CPU_ISA
@@ -36,6 +37,7 @@
 from fastchat.model.model_exllama import generate_stream_exllama
 from fastchat.model.model_xfastertransformer import generate_stream_xft
 from fastchat.model.model_cllm import generate_stream_cllm
+from fastchat.model.model_gemma3 import generate_stream_gemma3

 from fastchat.model.monkey_patch_non_inplace import (
     replace_llama_attn_with_non_inplace_operations,
@@ -419,6 +421,7 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
     is_xft = "xft" in model_type
     is_yuan = "yuan" in model_type
     is_cllm = "consistency-llm" in model_path.lower()
+    is_gemma3 = "gemma-3" in model_path.lower()

     if is_chatglm:
         return generate_stream_chatglm
@@ -434,6 +437,8 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
         return generate_stream_yuan2
     elif is_cllm:
         return generate_stream_cllm
+    elif is_gemma3:
+        return generate_stream_gemma3

     elif peft_share_base_weights and is_peft:
         # Return a curried stream function that loads the right adapter
@@ -458,6 +463,7 @@ def generate_stream_peft(
             is_xft = "xft" in base_model_type
             is_yuan = "yuan" in base_model_type
             is_cllm = "consistency-llm" in model_path.lower()
+            is_gemma3 = "gemma-3" in model_path.lower()

             generate_stream_function = generate_stream
             if is_chatglm:
@@ -474,6 +480,8 @@ def generate_stream_peft(
                 generate_stream_function = generate_stream_yuan2
             elif is_cllm:
                 generate_stream_function = generate_stream_cllm
+            elif is_gemma3:
+                generate_stream_function = generate_stream_gemma3
             for x in generate_stream_function(
                 model,
                 tokenizer,
@@ -822,6 +830,31 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
         )
         return model, tokenizer

+class Gemma3Adapter(BaseModelAdapter):
+    """The model adapter for google/gemma-3"""
+
+    def match(self, model_path: str):
+        return "gemma-3" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        device_map = from_pretrained_kwargs.get("device_map", None)
+        if device_map == "sequential":
+            device_map = "auto"
+        # print("From pretrained kwargs", from_pretrained_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_path, revision=revision)
+        model = Gemma3ForCausalLM.from_pretrained(
+            model_path,
+            revision=revision,
+            torch_dtype=torch.bfloat16,
+            device_map=device_map,
+        )
+        return model, tokenizer
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("gemma")
+

 class KoalaAdapter(BaseModelAdapter):
     """The model adapter for Koala"""
@@ -2505,8 +2538,12 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("api_based_default")


 # Note: the registration order matters.
 # The one registered earlier has a higher matching priority.
+register_model_adapter(Gemma3Adapter)
 register_model_adapter(PeftModelAdapter)
 register_model_adapter(StableVicunaAdapter)
 register_model_adapter(VicunaAdapter)
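
For orientation, here is a minimal sketch (not part of the commit) of how the new adapter gets picked up. It assumes a Gemma-3 chat checkpoint name such as "google/gemma-3-4b-it" and uses FastChat's existing get_model_adapter helper; because Gemma3Adapter is registered ahead of the generic adapters and its match() looks for "gemma-3" in the lowercased path, such a path should resolve to this adapter and to the "gemma" conversation template.

# Sketch only; "google/gemma-3-4b-it" is an assumed checkpoint name.
from fastchat.model.model_adapter import get_model_adapter

model_path = "google/gemma-3-4b-it"  # hypothetical Gemma-3 chat model
adapter = get_model_adapter(model_path)
print(type(adapter).__name__)                              # expected: Gemma3Adapter
print(adapter.get_default_conv_template(model_path).name)  # expected: gemma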

fastchat/model/model_gemma3.py (new file)
Lines changed: 132 additions & 0 deletions

from threading import Thread
import gc
import torch
from transformers import TextIteratorStreamer

def generate_stream_gemma3(
    model,
    tokenizer,
    params,
    device,
    context_len,
    stream_interval=2,
    judge_sent_end=False,
):
    """Custom generate stream function for Gemma-3 models"""
    # Get parameters from the request
    prompt = params.get("prompt", "")
    messages = params.get("messages", None)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    model_name = params.get("model", None)

    if tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    is_base_model = "pt" in model_name.lower() or "base" in model_name.lower()

    if not is_base_model:
        # Format input based on whether we have messages or a plain prompt
        if messages:
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
            ).to(model.device)
        else:
            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
            inputs = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
            ).to(model.device)
    else:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    input_ids = inputs["input_ids"]
    input_echo_len = input_ids.shape[1]

    # Configure generation parameters
    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0.0,
        "temperature": temperature if temperature > 0.0 else 1.0,
    }

    if top_p < 1.0:
        generate_kwargs["top_p"] = top_p
    if top_k > 0:
        generate_kwargs["top_k"] = top_k
    if repetition_penalty > 1.0:
        generate_kwargs["repetition_penalty"] = repetition_penalty

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=not echo, skip_special_tokens=True)
    generate_kwargs["streamer"] = streamer

    # Start generation in a separate thread
    thread = Thread(target=lambda: model.generate(input_ids=input_ids, **generate_kwargs))
    thread.start()

    # Track generation progress
    generated_tokens = 0
    output_text = ""

    # Stream tokens
    for new_text in streamer:
        output_text += new_text
        generated_tokens += 1

        # Check for stop strings
        should_stop = False
        if stop_str:
            if isinstance(stop_str, str):
                if stop_str in output_text:
                    output_text = output_text[: output_text.find(stop_str)]
                    should_stop = True
            elif isinstance(stop_str, list):
                for stop in stop_str:
                    if stop in output_text:
                        output_text = output_text[: output_text.find(stop)]
                        should_stop = True
                        break

        # Stream at intervals or when stopping
        if generated_tokens % stream_interval == 0 or should_stop:
            yield {
                "text": output_text,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": generated_tokens,
                    "total_tokens": input_echo_len + generated_tokens,
                },
                "finish_reason": "stop" if should_stop else None,
            }

        if should_stop:
            break

    # Final output with finish reason
    if thread.is_alive():
        thread.join(
            timeout=3600
        )  # Arbitrary value, but if it doesn't complete in this much time then something is wrong

    yield {
        "text": output_text,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": generated_tokens,
            "total_tokens": input_echo_len + generated_tokens,
        },
        "finish_reason": "length",
    }

    # Clean up
    gc.collect()
    torch.cuda.empty_cache()
    if device == "xpu":
        torch.xpu.empty_cache()
    if device == "npu":
        torch.npu.empty_cache()
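
A minimal usage sketch (not part of the commit) for the streaming function above. The checkpoint name, context length, and sampling parameters are illustrative assumptions, and it assumes a transformers version that ships Gemma3ForCausalLM, as the adapter in model_adapter.py does.

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM
from fastchat.model.model_gemma3 import generate_stream_gemma3

model_path = "google/gemma-3-4b-it"  # hypothetical Gemma-3 chat checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

params = {
    "model": model_path,  # used by generate_stream_gemma3 to detect base vs. chat models
    "prompt": "Explain what a model adapter does in FastChat.",
    "temperature": 0.7,
    "max_new_tokens": 128,
    "echo": False,
}

# The generator yields partial results; the last chunk carries the full text and usage stats.
last = None
for chunk in generate_stream_gemma3(model, tokenizer, params, device="cuda", context_len=8192):
    last = chunk
print(last["text"])
print(last["usage"], last["finish_reason"])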
