Commit fd1234c

ggerganov, ngxson, and slaren authored
llama : add gpt-oss (#15091)
* oai moe
* compat with new checkpoint
* add attn sink impl
* add rope scaling yarn
* logits match with latest transformers code
* wip chat template
* rm trailing space
* use ggml_scale_bias
* rm redundant is_swa_all
* convert interleaved gate_up
* graph : fix activation function to match reference (#7)
* vocab : handle o200k_harmony special tokens
* ggml : add attention sinks support (#1)
  * llama : add attn sinks
  * ggml : add attn sinks
  * cuda : add attn sinks
  * vulkan : add support for sinks in softmax
    remove unnecessary return
* ggml : add fused swiglu_oai op (#11)
  * ggml : add fused swiglu_oai op
  * Update ggml/src/ggml-cpu/ops.cpp
    Co-authored-by: Georgi Gerganov <[email protected]>
  * update CUDA impl
  * cont : metal impl
  * add vulkan impl
  * test-backend-ops : more test cases, clean up
  * llama : remove unfused impl
  * remove extra lines
  ---------
  Co-authored-by: Georgi Gerganov <[email protected]>
  ---------
  Co-authored-by: slaren <[email protected]>
* repack mxfp4 upon conversion
* clean up a bit
* enable thinking
* add quick hack to render only some special tokens
* fix bf16 conversion
* remove vocab hack
* webui ok
* support chat parsing for gpt-oss
* fix webui
* direct mapping mxfp4, FINALLY
* force using mxfp4
* properly use lazy tensor
* ggml : add mxfp4
  ggml : use e8m0 conversion instead of powf
  Co-authored-by: Diego Devesa <[email protected]>
  change kvalues_mxfp4 table to match e2m1 (#6)
  metal : remove quantization for now (not used)
  cuda : fix disabled CUDA graphs due to ffn moe bias
  vulkan : add support for mxfp4
  cont : add cm2 dequant
* ggml : add ggml_add_id (#13)
  * ggml : add ggml_add_id
  * add cuda impl
  * llama : add weight support check for add_id
  * perf opt
  * add vulkan impl
  * rename cuda files
  * add metal impl
  * allow in-place ggml_add_id
* llama : keep biases on CPU with --cpu-moe
* llama : fix compile error
  ggml-ci
* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw
  ggml-ci
* cleanup
  ggml-ci
* sycl : fix supports_op for MXFP4
  ggml-ci
* fix Unknown reasoning format
* ggml-cpu : fix AVX build
  ggml-ci
* fix hip build
  ggml-ci
* cuda : add mxfp4 dequantization support for cuBLAS
  ggml-ci
* ggml-cpu : fix mxfp4 fallback definitions for some architectures
  ggml-ci
* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: slaren <[email protected]>
1 parent f324a3b commit fd1234c


83 files changed, +2943 −228 lines

common/arg.cpp

Lines changed: 2 additions & 1 deletion
@@ -2947,11 +2947,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
         "- none: leaves thoughts unparsed in `message.content`\n"
         "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
-        "(default: deepseek)",
+        "(default: auto)",
         [](common_params & params, const std::string & value) {
             /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
             else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
             else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
+            else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
             else { throw std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));

common/chat.cpp

Lines changed: 30 additions & 0 deletions
@@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
         case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
         case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+        case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
         default:
             throw std::runtime_error("Unknown chat format");
     }
@@ -614,6 +615,7 @@ const char * common_chat_format_name(common_chat_format format) {
 const char * common_reasoning_format_name(common_reasoning_format format) {
     switch (format) {
         case COMMON_REASONING_FORMAT_NONE: return "none";
+        case COMMON_REASONING_FORMAT_AUTO: return "auto";
         case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
         case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
         default:
@@ -1303,6 +1305,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
         tool_calls_end);
 }

+static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    auto prompt = apply(tmpl, inputs);
+
+    data.prompt = prompt;
+    data.format = COMMON_CHAT_FORMAT_GPT_OSS;
+
+    // TODO: support tool calls in GPT-OSS?
+
+    return data;
+}
+static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
+    // TODO @ngxson : this won't work with --special enabled, we should fix that
+    builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+}
+
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
@@ -1788,6 +1810,11 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_hermes_2_pro(tmpl, params);
     }

+    // GPT-OSS
+    if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
+        return common_chat_params_init_gpt_oss(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1939,6 +1966,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_COMMAND_R7B:
             common_chat_parse_command_r7b(builder);
             break;
+        case COMMON_CHAT_FORMAT_GPT_OSS:
+            common_chat_parse_gpt_oss(builder);
+            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }
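Note: the new `common_chat_parse_gpt_oss` handler splits a raw completion on the o200k_harmony channel markers, treating the text between `<|channel|>analysis<|message|>` and `<|start|>assistant<|channel|>final<|message|>` as reasoning and the remainder as the final answer. The standalone sketch below is not part of this commit (the example completion text is made up); it only illustrates the same split outside of the `common_chat_msg_parser` machinery:

```cpp
// Standalone illustration of the split performed via builder.try_parse_reasoning().
#include <iostream>
#include <string>

int main() {
    const std::string raw =
        "<|channel|>analysis<|message|>The user wants a greeting."
        "<|start|>assistant<|channel|>final<|message|>Hello!";

    const std::string think_open  = "<|channel|>analysis<|message|>";
    const std::string think_close = "<|start|>assistant<|channel|>final<|message|>";

    std::string reasoning;
    std::string content = raw;

    const size_t open_pos = raw.find(think_open);
    if (open_pos != std::string::npos) {
        const size_t close_pos = raw.find(think_close, open_pos);
        if (close_pos != std::string::npos) {
            // everything between the two markers is the model's reasoning
            reasoning = raw.substr(open_pos + think_open.size(), close_pos - open_pos - think_open.size());
            // everything after the "final" marker is the user-visible answer
            content   = raw.substr(close_pos + think_close.size());
        }
    }

    std::cout << "reasoning_content: " << reasoning << "\n";
    std::cout << "content:           " << content   << "\n";
}
```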

common/chat.h

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_GPT_OSS,

     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };

common/common.h

Lines changed: 2 additions & 1 deletion
@@ -236,6 +236,7 @@ struct common_params_diffusion {

 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
+    COMMON_REASONING_FORMAT_AUTO,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
     COMMON_REASONING_FORMAT_DEEPSEEK,        // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
@@ -394,7 +395,7 @@ struct common_params {
     std::string chat_template = ""; // NOLINT
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
-    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
     int reasoning_budget = -1;
     bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

convert_hf_to_gguf.py

Lines changed: 114 additions & 0 deletions
@@ -7950,6 +7950,119 @@ def set_vocab(self):
         self.gguf_writer.add_chat_template(chat_template)


+@ModelBase.register("GptOssForCausalLM")
+class GptOssModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.GPT_OSS
+
+    def transform_nibble_layout(self, tensor):
+        assert tensor.dtype == torch.uint8
+        assert tensor.shape[-1] == 16
+        # swap nibbles
+        t_lo = tensor & 0x0F
+        t_hi = tensor & 0xF0
+        t_swapped = (t_lo << 4) | (t_hi >> 4)
+        tensor = t_swapped
+        # transform aaaa...bbbb... to abababab...
+        blk_a, blk_b = tensor.chunk(2, dim=-1)
+        # get a_
+        blk_a0 = (blk_a & 0xF0).view(-1, 1)
+        blk_a1 = (blk_a << 4).view(-1, 1)
+        blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
+        # get _b
+        blk_b0 = (blk_b >> 4).view(-1, 1)
+        blk_b1 = (blk_b & 0x0F).view(-1, 1)
+        blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
+        # swap once more
+        out = blk_a | blk_b
+        out_h = out & 0xF0
+        out_l = out & 0x0F
+        out = (out_h >> 4) | (out_l << 4)
+        return out
+
+    def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
+        assert blocks.dtype == torch.uint8
+        assert scales.dtype == torch.uint8
+        scales = scales.unsqueeze(-1)
+        assert len(blocks.shape) == 4
+        assert len(scales.shape) == 4
+        blocks = self.transform_nibble_layout(blocks)
+        new_data = torch.concat((scales, blocks), dim=-1)
+        new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
+        logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
+        # flatten last dim
+        new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
+        new_data = new_data.numpy()
+        self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        blocks0: Tensor = torch.zeros(1)
+        blocks1: Tensor = torch.zeros(1)
+        found_mxfp4_tensors = False
+        # we assume that tensors are loaded in the correct order
+        for name, data_torch in self.get_tensors():
+            if "mlp.experts.down_proj_blocks" in name:
+                blocks0 = data_torch
+            elif "mlp.experts.down_proj_scales" in name:
+                new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
+                self.repack_mxfp4(new_name, blocks0, data_torch)
+                found_mxfp4_tensors = True
+            elif "mlp.experts.gate_up_proj_blocks" in name:
+                blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
+            elif "mlp.experts.gate_up_proj_scales" in name:
+                scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
+                new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
+                new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
+                self.repack_mxfp4(new_name_gate, blocks0, scales0)
+                self.repack_mxfp4(new_name_up, blocks1, scales1)
+                found_mxfp4_tensors = True
+        if not found_mxfp4_tensors:
+            raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
+        return []
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if "sinks" in name:
+            name += ".weight"
+
+        # correct naming for down_proj
+        if "down_proj" in name:
+            if name.endswith("_bias"):
+                name = name.replace("down_proj_bias", "down_proj.bias")
+            else:
+                return []
+
+        # split the gate_up into gate and up
+        if "gate_up_proj" in name:
+            if name.endswith("_bias"):
+                name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
+                name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
+                gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
+                return [
+                    (self.map_tensor_name(name_gate), gate_proj_bias),
+                    (self.map_tensor_name(name_up), up_proj_bias)
+                ]
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])
+
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
+        assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+        self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+        self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
+
+
 @ModelBase.register("Lfm2ForCausalLM")
 @ModelBase.register("LFM2ForCausalLM")
 class LFM2Model(TextModel):
@@ -8089,6 +8202,7 @@ class LazyTorchTensor(gguf.LazyBase):
     _dtype_map: dict[torch.dtype, type] = {
         torch.float16: np.float16,
         torch.float32: np.float32,
+        torch.uint8: np.uint8,
     }

     # used for safetensors slices
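For reference, `repack_mxfp4` above concatenates one scale byte in front of each group of 16 packed bytes, which suggests an on-disk MXFP4 block of 17 bytes: an E8M0 scale followed by 32 E2M1 nibbles. The sketch below is only an illustration of dequantizing such a block under those assumptions; the E2M1 magnitude table follows the OCP Microscaling spec, and the within-byte nibble ordering is a guess that may not match ggml's actual layout after `transform_nibble_layout`:

```cpp
// Rough sketch: decode one assumed 17-byte MXFP4 block (1 E8M0 scale + 16 packed E2M1 bytes = 32 weights).
#include <cmath>
#include <cstdint>
#include <cstdio>

// E2M1 magnitudes; the high bit of each nibble is assumed to be the sign.
static const float kE2M1[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

static void dequant_mxfp4_block(const uint8_t block[17], float out[32]) {
    const float scale = std::ldexp(1.0f, (int) block[0] - 127); // E8M0: 2^(e - 127)
    for (int i = 0; i < 16; ++i) {
        const uint8_t byte    = block[1 + i];
        const uint8_t nibs[2] = { (uint8_t)(byte & 0x0F), (uint8_t)(byte >> 4) }; // low-nibble-first (assumption)
        for (int j = 0; j < 2; ++j) {
            const float mag  = kE2M1[nibs[j] & 0x07];
            const float sign = (nibs[j] & 0x08) ? -1.0f : 1.0f;
            out[2*i + j] = sign * mag * scale;
        }
    }
}

int main() {
    uint8_t block[17] = { 127 }; // scale byte 127 -> 2^0; remaining bytes default to 0 (all-zero weights)
    block[1] = 0x2A;             // low nibble 0xA -> -1.0, high nibble 0x2 -> +1.0 under the assumed layout
    float w[32];
    dequant_mxfp4_block(block, w);
    std::printf("w[0]=%g w[1]=%g\n", w[0], w[1]);
    return 0;
}
```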

ggml/include/ggml.h

Lines changed: 37 additions & 1 deletion
@@ -304,6 +304,16 @@
     GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb, dst, nb)

+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
     // GGML_TYPE_IQ4_NL_4_4 = 36,
     // GGML_TYPE_IQ4_NL_4_8 = 37,
     // GGML_TYPE_IQ4_NL_8_8 = 38,
-    GGML_TYPE_COUNT = 39,
+    GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+    GGML_TYPE_COUNT = 40,
 };

 // precision
@@ -430,6 +441,7 @@ extern "C" {
     GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
     GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
     GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
+    GGML_FTYPE_MOSTLY_MXFP4  = 25, // except 1d tensors
 };

 // available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {

     GGML_OP_DUP,
     GGML_OP_ADD,
+    GGML_OP_ADD_ID,
     GGML_OP_ADD1,
     GGML_OP_ACC,
     GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
     GGML_GLU_OP_REGLU,
     GGML_GLU_OP_GEGLU,
     GGML_GLU_OP_SWIGLU,
+    GGML_GLU_OP_SWIGLU_OAI,
     GGML_GLU_OP_GEGLU_ERF,
     GGML_GLU_OP_GEGLU_QUICK,

@@ -831,6 +845,13 @@ extern "C" {
             struct ggml_tensor  * b,
             enum   ggml_type      type);

+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1198,6 +1219,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 alpha,
+            float                 limit);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
             float                 scale,
             float                 max_bias);

+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2052,6 +2084,10 @@ extern "C" {
     GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
             const struct ggml_tensor * a);

+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
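The header additions give a feel for how the new ops compose in a MoE FFN: `ggml_add_id` adds a per-expert row of `b` selected by `ids` to each slice of `a`, and `ggml_swiglu_oai` fuses the clamped SwiGLU activation used by gpt-oss. The sketch below only exercises the API surface and is not code from this commit; the tensor shapes, the gate/up argument order, and the `alpha`/`limit` values (1.702 / 7.0, taken from the GPT-OSS reference implementation) are assumptions:

```cpp
// Minimal graph-construction sketch using the newly added ops (shapes and constants are illustrative).
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_ff = 32, n_expert_used = 2, n_tokens = 4, n_expert = 8;

    // per-expert gate/up activations: [n_ff, n_expert_used, n_tokens]
    struct ggml_tensor * gate = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_ff, n_expert_used, n_tokens);
    struct ggml_tensor * up   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_ff, n_expert_used, n_tokens);

    // per-expert biases and the ids of the experts selected for each token
    struct ggml_tensor * up_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_expert);
    struct ggml_tensor * ids  = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, n_tokens);

    // dst[i0, i1, i2] = up[i0, i1, i2] + up_b[i0, ids[i1, i2]]
    struct ggml_tensor * up_biased = ggml_add_id(ctx, up, up_b, ids);

    // fused clamped SwiGLU; gate/up order and the alpha/limit constants are assumed here
    struct ggml_tensor * act = ggml_swiglu_oai(ctx, gate, up_biased, 1.702f, 7.0f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, act);
    // ... fill the input tensors, then compute the graph on a backend of choice

    ggml_free(ctx);
    return 0;
}
```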

ggml/src/ggml-alloc.c

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 8 additions & 0 deletions
@@ -2340,6 +2340,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
             return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -2354,6 +2358,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[4]) {
+                return false;
+            }
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 // different head sizes of K and V are not supported yet
                 return false;
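The `op->src[2]` / `op->src[4]` checks above opt the CANN backend out of the new attention-sink inputs that `ggml_soft_max_add_sinks` / `ggml_flash_attn_ext_add_sinks` attach, since its kernels do not handle them yet. Conceptually, a sink is an extra per-head logit that joins the softmax normalizer, so the regular attention weights sum to less than 1. The scalar sketch below illustrates that idea only; it is not the actual ggml kernel:

```cpp
// Scalar illustration of softmax with an attention sink logit.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> softmax_with_sink(const std::vector<float> & logits, float sink) {
    float max_l = sink;
    for (float l : logits) max_l = std::max(max_l, l);

    // the sink contributes to the normalizer but produces no output weight
    float denom = std::exp(sink - max_l);
    for (float l : logits) denom += std::exp(l - max_l);

    std::vector<float> probs(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_l) / denom;
    }
    return probs;
}

int main() {
    const std::vector<float> scores = { 0.2f, 1.5f, -0.7f };
    const std::vector<float> p = softmax_with_sink(scores, /*sink=*/1.0f);
    float sum = 0.0f;
    for (float x : p) sum += x;
    std::printf("sum of attention weights = %.3f (< 1 because of the sink)\n", sum);
    return 0;
}
```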
