Skip to content

Commit 7c38d8f

Browse files
ArthurZuckerqgallouedecSunMarcedbeechingVaibhavs10
authored
Add GPT OSS model from OpenAI (#39923)
* fix * nice * where i am at * Bro this works * Update src/transformers/integrations/tensor_parallel.py * cleanups * yups that was breaking * Update src/transformers/models/openai_moe/modeling_openai_moe.py * gather on experts and not mlp * add changes for latest convert branch * adds options to get output_router_logits from config * bring chat temlate + special tokens back into the script. * initial commmit * update * working with shards * add model.safetensors.index.json * fix * fix * mxfp4 flag * rm print * Fix PAD/EOS/BOS (#18) * fix pad/eos/bos * base model maybe one day * add some doc * special tokens based on harmony. * add in tokenizer config as well. * prepare for rebase with main * Fix for initialize_tensor_parallelism now returning 4-tuple ``` [rank0]: File "/fsx/edward/work/openai-tsm-examples/examples/generate.py", line 17, in <module> [rank0]: model = AutoModelForCausalLM.from_pretrained( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/edward/work/new-model-addition-openai/src/transformers/models/auto/auto_factory.py", line 600, in from_pretrained [rank0]: return model_class.from_pretrained( [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/edward/work/new-model-addition-openai/src/transformers/modeling_utils.py", line 316, in _wrapper [rank0]: return func(*args, **kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/fsx/edward/work/new-model-addition-openai/src/transformers/modeling_utils.py", line 4748, in from_pretrained [rank0]: tp_plan, device_map, device_mesh = initialize_tensor_parallelism(tp_plan, tp_size=None) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: ValueError: too many values to unpack (expected 3) ``` * mxfp4 * mxfp4 draft * fix * fix import * draft * draft impl * finally working ! * simplify * add import * working version * consider blocks and scales * device mesh fix * initial commit * add working dequant + quant logic * update * non nan, gibberish output * working EP + quantization finally ! * start cleaning * remove reversing process * style * some cleaning * initial commmit * more cleaning * more cleaning * simplify * more cleaning * rm duplicated function * changing tp_plan * update tp plan check * add loading attribute * dequantizing logic * use subfunctions * import cleaning * update_param_name * adds clamped swiglu * add clamping to training path * simplify dequant logic * update * Bad merge * more simplifications & tests * fix ! * fix registering custom attention * fix order * fixes * some test nits * nits * nit * fix * Clamp sink logits * Clean * Soft-max trick * Clean up * p * fix deepspeed * update both modeling and modular for cleanup * contiguous * update tests * fix top_k router call * revert renaming * test nits * small fixes for EP * fix path for our local tests * update as I should not have broken that! * fix the loss of mixtral * revert part of the changes related to router_scores, kernel probably no ready for that! * deleting a small nit * update arch * fix post processing * update * running version but not expected output * moving to cuda * initial commit * revert * erroring when loading on cpu * updates * del blocks, scales * fix * style * rm comm * comment * add comment * style * remove duplicated lines * Fix minor issue with weight_map conversion script * fix sampling params * rename to final name * upate pre-final version of template * Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py * fix batched inference * serve fixes * swizzle ! * update final chat template by Matt. * fix responses; pin oai * sinplify * Thanks Matt for his tireless efforts! Co-authored-by: Rocketknight1 <[email protected]> * Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py Co-authored-by: Matt <[email protected]> * fix * Use ROCm kernels from HUB * Make kernel modes explicit * update final chat template by Matt. x2 * Thanks Matt for his tireless efforts! Co-authored-by: Rocketknight1 <[email protected]> * Fix installation * Update setup.py Co-authored-by: Ákos Hadnagy <[email protected]> * allow no content * fix: update message handling in write_tokenizer function * Fix template logic for user message role * last nits for CB and flash_paged! * there was one bad merge * fix CB (hardcode for now, its just using kv groups instead) * fix * better fix for device_map * minor device fix * Fix flash paged * updates * Revert "remove dtensors, not explicit (#39840)" This reverts commit 6dfd561. * update * Revert "remove dtensors, not explicit (#39840)" This reverts commit 6dfd561. * fix merge * fix * Fix line break when custom model indentity * nits testing * to locals first and pass sliding window to flash paged * register modes for MegaBlocksMoeMlp * add integration test in fixtures -> now update the tests to use it! * update integration tests * initial fix * style and update tests * fix * chore(gpt oss): remove mlp_bias from configuration It was just a leftover. * stats * Integration tests * whoops * Shouldn't move model * Ensure assistant messages without thinking always go to "final" channel * More checks to ensure expected format * Add pad_token_id to model configuration in write_model function (#51) * Add oai fix fast tests (#59) * Fix some fast tests * Force some updates * Remove unnecessary fixes * Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py Co-authored-by: Quentin Gallouédec <[email protected]> * Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py Co-authored-by: Quentin Gallouédec <[email protected]> * Update src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py * reasoning -> Reasoning * Add additional integration tests * fixup * Slight fixes * align chat template with harmony * simplify * Add comment * torch testing assert close * torch testing assert close * torch testing assert close * torch testing assert close * torch testing assert close * torch testing assert close * Revert fixup * skip 2 test remove todo * merge * padding side should be left for integration tests * fix modular wrt to changes made to modeling * style * isort * fix opies for the loss * mmmm --------- Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Marc Sun <[email protected]> Co-authored-by: edbeeching <[email protected]> Co-authored-by: Vaibhavs10 <[email protected]> Co-authored-by: MekkCyber <[email protected]> Co-authored-by: Marc Sun <[email protected]> Co-authored-by: Edward Beeching <[email protected]> Co-authored-by: Mohamed Mekkouri <[email protected]> Co-authored-by: Lewis Tunstall <[email protected]> Co-authored-by: Zhuohan Li <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: [email protected] <[email protected]> Co-authored-by: Rocketknight1 <[email protected]> Co-authored-by: Joao Gante <[email protected]> Co-authored-by: Akos Hadnagy <[email protected]> Co-authored-by: Ákos Hadnagy <[email protected]> Co-authored-by: Alvaro Moran <[email protected]> Co-authored-by: Lysandre <[email protected]> Co-authored-by: Matt <[email protected]>
1 parent 738c1a3 commit 7c38d8f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+4668
-98
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,8 @@
617617
title: OLMoE
618618
- local: model_doc/open-llama
619619
title: Open-Llama
620+
- local: model_doc/openai_moe
621+
title: OpenAIMoe
620622
- local: model_doc/opt
621623
title: OPT
622624
- local: model_doc/pegasus
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
12+
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
13+
rendered properly in your Markdown viewer.
14+
15+
-->
16+
17+
<div style="float: right;">
18+
<div class="flex flex-wrap space-x-1">
19+
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
20+
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=
21+
">
22+
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
23+
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
24+
</div>
25+
</div>
26+
27+
# OpenAIMoE
28+
29+
## Overview
30+
31+
The OpenAIMoE model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
32+
<INSERT SHORT SUMMARY HERE>
33+
34+
The abstract from the paper is the following:
35+
36+
*<INSERT PAPER ABSTRACT HERE>*
37+
38+
Tips:
39+
40+
<INSERT TIPS ABOUT MODEL HERE>
41+
42+
This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
43+
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).
44+
45+
46+
## OpenAIMoeConfig
47+
48+
[[autodoc]] OpenAIMoeConfig
49+
50+
## OpenAIMoeModel
51+
52+
[[autodoc]] OpenAIMoeModel
53+
- forward
54+
55+
## OpenAIMoeForCausalLM
56+
57+
[[autodoc]] OpenAIMoeForCausalLM
58+
- forward

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
129129
"keras>2.9,<2.16",
130130
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
131-
"kernels>=0.6.1,<0.7",
131+
"kernels>=0.6.1,<=0.9",
132132
"librosa",
133133
"natten>=0.14.6,<0.15.0",
134134
"nltk<=3.8.1",
@@ -137,7 +137,7 @@
137137
"onnxconverter-common",
138138
"onnxruntime-tools>=1.4.2",
139139
"onnxruntime>=1.4.0",
140-
"openai",
140+
"openai>=1.98.0",
141141
"opencv-python",
142142
"optimum-benchmark>=0.3.0",
143143
"optuna",

src/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@
277277
"GPTQConfig",
278278
"HiggsConfig",
279279
"HqqConfig",
280+
"Mxfp4Config",
280281
"QuantoConfig",
281282
"QuarkConfig",
282283
"FPQuantConfig",

src/transformers/commands/serving.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,16 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
909909
inputs = inputs.to(model.device)
910910
request_id = req.get("request_id", "req_0")
911911

912-
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
912+
# Temporary hack for GPTOSS 1: don't filter special tokens
913+
skip_special_tokens = True
914+
if "gptoss" in model.config.architectures[0].lower():
915+
skip_special_tokens = False
916+
917+
generation_streamer = TextIteratorStreamer(
918+
processor,
919+
skip_special_tokens=skip_special_tokens,
920+
skip_prompt=True,
921+
)
913922
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
914923

915924
last_kv_cache = None
@@ -925,12 +934,21 @@ def generate_chat_completion(self, req: dict) -> Generator[str, None, None]:
925934
}
926935

927936
def stream_chat_completion(streamer, _request_id):
937+
# Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
938+
# classes and piping the reasoning trace into a new field
939+
filter_cot = False
940+
cot_trace_end = None
941+
if "gptoss" in model.config.architectures[0].lower():
942+
filter_cot = True
943+
cot_trace_end = "<|channel|>final<|message|>"
944+
928945
# Thin wrapper to save the KV cache after generation
929946
def generate_with_cache(**kwargs):
930947
generate_output = model.generate(**kwargs)
931948
self.last_kv_cache = generate_output.past_key_values
932949

933950
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
951+
results = ""
934952

935953
try:
936954
thread.start()
@@ -941,6 +959,20 @@ def generate_with_cache(**kwargs):
941959
yield self.build_chat_completion_chunk(request_id, role="assistant", model=model_id_and_revision)
942960

943961
for result in streamer:
962+
# Temporary hack for GPTOS 3: don't emit the final "<|return|>"
963+
if "gptoss" in model.config.architectures[0].lower():
964+
if result.endswith("<|return|>"):
965+
result = result[: -len("<|return|>")]
966+
results += result
967+
968+
# (related to temporary hack 2)
969+
if filter_cot:
970+
if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
971+
filter_cot = False
972+
continue
973+
else:
974+
continue
975+
944976
# ====== TOOL CALL LOGIC ======
945977
if tool_model_family is not None:
946978
# Start of a tool call: reset state variables, set `inside_tool_call`
@@ -1064,7 +1096,16 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
10641096
inputs = inputs.to(model.device)
10651097
request_id = req.get("previous_response_id", "req_0")
10661098

1067-
generation_streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
1099+
# Temporary hack for GPTOSS 1: don't filter special tokens
1100+
skip_special_tokens = True
1101+
if "gptoss" in model.config.architectures[0].lower():
1102+
skip_special_tokens = False
1103+
1104+
generation_streamer = TextIteratorStreamer(
1105+
processor,
1106+
skip_special_tokens=skip_special_tokens,
1107+
skip_prompt=True,
1108+
)
10681109
generation_config = create_generation_config_from_req(req, model_generation_config=model.generation_config)
10691110

10701111
last_kv_cache = None
@@ -1081,6 +1122,14 @@ def generate_response(self, req: dict) -> Generator[str, None, None]:
10811122
}
10821123

10831124
def stream_response(streamer, _request_id):
1125+
# Temporary hack for GPTOS 2: filter out the CoT tokens. Full solution here implies defining new output
1126+
# classes and piping the reasoning trace into a new field
1127+
filter_cot = False
1128+
cot_trace_end = None
1129+
if "gptoss" in model.config.architectures[0].lower():
1130+
filter_cot = True
1131+
cot_trace_end = "<|channel|>final<|message|>"
1132+
10841133
# Thin wrapper to save the KV cache after generation
10851134
def generate_with_cache(**kwargs):
10861135
generate_output = model.generate(**kwargs)
@@ -1167,14 +1216,29 @@ def generate_with_cache(**kwargs):
11671216
# Stream the actual generated text
11681217
results = ""
11691218
for result in streamer:
1219+
# Temporary hack for GPTOS 3: don't emit the final "<|return|>"
1220+
if "gptoss" in model.config.architectures[0].lower():
1221+
if result.endswith("<|return|>"):
1222+
result = result[: -len("<|return|>")]
11701223
results += result
1224+
1225+
# (related to temporary hack 2)
1226+
if filter_cot:
1227+
if cot_trace_end in results: # end of reasoning trace observed -> stop filtering
1228+
filter_cot = False
1229+
results = "" # reset the results -> results will now track the final response
1230+
continue
1231+
else:
1232+
continue
1233+
11711234
response_output_text_delta = ResponseTextDeltaEvent(
11721235
type="response.output_text.delta",
11731236
item_id=f"msg_{request_id}",
11741237
sequence_number=sequence_number,
11751238
output_index=output_index,
11761239
content_index=content_index,
11771240
delta=result,
1241+
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
11781242
)
11791243
sequence_number += 1
11801244
yield self.build_response_event(response_output_text_delta)
@@ -1187,6 +1251,7 @@ def generate_with_cache(**kwargs):
11871251
output_index=output_index,
11881252
content_index=0,
11891253
text=results,
1254+
logprobs=[{"token": "", "logprob": 99.9}], # TODO: add actual logprobs
11901255
)
11911256
sequence_number += 1
11921257
yield self.build_response_event(response_output_text_done)
@@ -1446,9 +1511,10 @@ def _load_model_and_data_processor(self, model_id_and_revision: str):
14461511
"attn_implementation": args.attn_implementation,
14471512
"torch_dtype": torch_dtype,
14481513
"device_map": "auto",
1449-
"quantization_config": quantization_config,
14501514
"trust_remote_code": args.trust_remote_code,
14511515
}
1516+
if quantization_config is not None:
1517+
model_kwargs["quantization_config"] = quantization_config
14521518

14531519
config = AutoConfig.from_pretrained(model_id, **model_kwargs)
14541520
architecture = getattr(transformers, config.architectures[0])

src/transformers/dependency_versions_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
"kenlm": "kenlm",
3535
"keras": "keras>2.9,<2.16",
3636
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
37-
"kernels": "kernels>=0.6.1,<0.7",
37+
"kernels": "kernels>=0.6.1,<=0.9",
3838
"librosa": "librosa",
3939
"natten": "natten>=0.14.6,<0.15.0",
4040
"nltk": "nltk<=3.8.1",
@@ -43,7 +43,7 @@
4343
"onnxconverter-common": "onnxconverter-common",
4444
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
4545
"onnxruntime": "onnxruntime>=1.4.0",
46-
"openai": "openai",
46+
"openai": "openai>=1.98.0",
4747
"opencv-python": "opencv-python",
4848
"optimum-benchmark": "optimum-benchmark>=0.3.0",
4949
"optuna": "optuna",

src/transformers/generation/continuous_batching.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -182,27 +182,29 @@ def __init__(
182182
f"Number of key value heads {num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
183183
)
184184
# If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
185-
self.num_key_value_heads //= tp_size
185+
# self.num_key_value_heads //= tp_size
186186

187187
self.head_dim = (
188188
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
189189
)
190190
self.num_hidden_layers = config.num_hidden_layers
191191

192192
# Calculate optimal block size and number if not provided
193-
num_blocks = getattr(generation_config, "num_blocks", None)
193+
num_blocks = getattr(generation_config, "num_blocks", 1024)
194194
block_size = getattr(generation_config, "block_size", 32)
195195
max_memory_percent = getattr(generation_config, "max_memory", 0.9)
196-
num_blocks, max_batch_tokens = compute_optimal_blocks(
197-
generation_config.max_new_tokens,
198-
block_size=block_size,
199-
head_dim=self.head_dim,
200-
num_layers=self.num_hidden_layers,
201-
num_heads=self.num_key_value_heads,
202-
max_memory_percent=max_memory_percent,
203-
dtype=dtype,
204-
num_blocks=num_blocks,
205-
)
196+
max_batch_tokens = getattr(generation_config, "max_batch_tokens", 256)
197+
if num_blocks is None or max_batch_tokens is None:
198+
num_blocks, max_batch_tokens = compute_optimal_blocks(
199+
generation_config.max_new_tokens,
200+
block_size=block_size,
201+
head_dim=self.head_dim,
202+
num_layers=self.num_hidden_layers,
203+
num_heads=self.num_key_value_heads,
204+
max_memory_percent=max_memory_percent,
205+
dtype=dtype,
206+
num_blocks=num_blocks,
207+
)
206208
logger.warning(
207209
f"Using calculated num_blocks={num_blocks}, block_size={block_size}, max concurrent requests {max_batch_tokens}"
208210
)
@@ -960,7 +962,14 @@ def _build_tensors(
960962

961963
@traced
962964
def _sync(self):
963-
return self.output_ids.tolist()[0] # should be the only synch we do
965+
if self.output_ids is not None:
966+
try:
967+
out = self.output_ids.tolist()[0] # should be the only synch we do
968+
except Exception:
969+
out = [0, 1]
970+
else:
971+
out = [0, 0]
972+
return out
964973

965974
@traced
966975
def _maybe_send_output(self, state: RequestState, token: int):
@@ -1250,7 +1259,7 @@ def _run_generation_loop(self):
12501259
self.model.device,
12511260
self.model.dtype,
12521261
num_requests=len(self.input_queue.queue),
1253-
tp_size=getattr(self.model, "tp_size"),
1262+
tp_size=getattr(self.model, "_tp_size", 8), # TODO quantized converted don't set this
12541263
)
12551264

12561265
scheduler = None

src/transformers/integrations/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@
119119
"run_hp_search_sigopt",
120120
"run_hp_search_wandb",
121121
],
122+
"mxfp4": [
123+
"Mxfp4GptOssExperts",
124+
"convert_moe_packed_tensors",
125+
"dequantize",
126+
"load_and_swizzle_mxfp4",
127+
"quantize_to_mxfp4",
128+
"replace_with_mxfp4_linear",
129+
],
122130
"peft": ["PeftAdapterMixin"],
123131
"quanto": ["replace_with_quanto_layers"],
124132
"spqr": ["replace_with_spqr_linear"],
@@ -255,6 +263,13 @@
255263
run_hp_search_sigopt,
256264
run_hp_search_wandb,
257265
)
266+
from .mxfp4 import (
267+
Mxfp4GptOssExperts,
268+
dequantize,
269+
load_and_swizzle_mxfp4,
270+
quantize_to_mxfp4,
271+
replace_with_mxfp4_linear,
272+
)
258273
from .peft import PeftAdapterMixin
259274
from .quanto import replace_with_quanto_layers
260275
from .spqr import replace_with_spqr_linear

src/transformers/integrations/flash_paged.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ def paged_attention_forward(
5050
"""
5151
k, v = cache.update(k, v, module.layer_idx, cumulative_seqlens_k=cumulative_seqlens_k, **kwargs)
5252

53+
sliding_window = (-1, -1) if not getattr(module, "sliding_window", False) else (module.sliding_window, 0)
5354
if implementation is not None:
5455
flash_attn_varlen_func = implementation.flash_attn_varlen_func
56+
custom_kwargs = {"s_aux": kwargs.get("s_aux")}
5557
attn_output = flash_attn_varlen_func(
5658
q.transpose(1, 2).squeeze(0).contiguous(),
5759
k.transpose(1, 2).squeeze(0).contiguous(),
@@ -62,9 +64,9 @@ def paged_attention_forward(
6264
max_seqlen_k,
6365
softmax_scale=module.scaling,
6466
causal=True, # kind of a must, it automatically aligns the mask for q < k
65-
window_size=(-1, -1), # -1 means infinite context window
67+
window_size=sliding_window, # -1 means infinite context window
6668
# block_table=block_tables, -> torch.Tensor
67-
# **kwargs,
69+
**custom_kwargs,
6870
)
6971
if isinstance(attn_output, tuple):
7072
attn_output = attn_output[0]

0 commit comments

Comments
 (0)