Skip to content

Commit 1f3c7dd

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents e39f1b2 + 5d195f1 commit 1f3c7dd

File tree

5 files changed

+62
-103
lines changed

5 files changed

+62
-103
lines changed

convert_hf_to_gguf.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,17 @@ def get_audio_config(self) -> dict[str, Any] | None:
14971497
def set_type(self):
14981498
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
14991499

1500+
def prepare_metadata(self, vocab_only: bool):
1501+
super().prepare_metadata(vocab_only=vocab_only)
1502+
1503+
output_type: str = self.ftype.name.partition("_")[2]
1504+
1505+
if self.fname_out.is_dir():
1506+
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
1507+
self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
1508+
else:
1509+
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
1510+
15001511
def set_gguf_parameters(self):
15011512
self.gguf_writer.add_file_type(self.ftype)
15021513

@@ -8943,6 +8954,13 @@ def set_vocab(self):
89438954
class GptOssModel(TextModel):
89448955
model_arch = gguf.MODEL_ARCH.GPT_OSS
89458956

8957+
# TODO: remove once MXFP4 is supported more generally
8958+
def dequant_model(self):
8959+
quant_config = self.hparams.get("quantization_config")
8960+
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
8961+
return
8962+
return super().dequant_model()
8963+
89468964
def transform_nibble_layout(self, tensor):
89478965
assert tensor.dtype == torch.uint8
89488966
assert tensor.shape[-1] == 16
@@ -9722,10 +9740,6 @@ def main() -> None:
97229740

97239741
logger.info(f"Loading model: {dir_model.name}")
97249742

9725-
if args.mmproj:
9726-
if "mmproj" not in fname_out.name:
9727-
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
9728-
97299743
is_mistral_format = args.mistral_format
97309744
if is_mistral_format and not _mistral_common_installed:
97319745
raise ImportError(_mistral_import_error_msg)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
9696

9797
#define GGML_VK_MAX_NODES 8192
9898

99-
#define MAX_VK_BUFFERS 256
100-
10199
#define VK_CHECK(err, msg) \
102100
do { \
103101
vk::Result err_ = (err); \
@@ -1311,7 +1309,6 @@ struct ggml_vk_garbage_collector {
13111309
std::vector<vk_semaphore> tl_semaphores;
13121310
std::vector<vk_semaphore> semaphores;
13131311
std::vector<vk::Event> events;
1314-
std::vector<vk_buffer> temp_buffers;
13151312
std::vector<vk_context> contexts;
13161313
};
13171314

@@ -1482,8 +1479,6 @@ struct ggml_backend_vk_context {
14821479
// and set to true after the buffer contents are consumed.
14831480
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
14841481

1485-
vk_buffer buffer_pool[MAX_VK_BUFFERS];
1486-
14871482
vk_context_ref compute_ctx;
14881483
vk_context_ref transfer_ctx;
14891484

@@ -3623,8 +3618,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
36233618

36243619
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
36253620

3626-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
3627-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
3621+
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
3622+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3623+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3624+
} else {
3625+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3626+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3627+
}
36283628

36293629
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
36303630

@@ -5144,71 +5144,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
51445144
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
51455145
}
51465146

5147-
static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
5148-
VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
5149-
VK_LOG_MEMORY("ggml_vk_pool_malloc");
5150-
5151-
int best_i = -1;
5152-
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
5153-
int worst_i = -1;
5154-
size_t worst_size = 0; //largest unused buffer seen so far
5155-
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
5156-
vk_buffer &b = ctx->buffer_pool[i];
5157-
if (b != nullptr && b->size >= size && b->size < best_size) {
5158-
best_i = i;
5159-
best_size = b->size;
5160-
}
5161-
if (b != nullptr && b->size > worst_size) {
5162-
worst_i = i;
5163-
worst_size = b->size;
5164-
}
5165-
}
5166-
if(best_i != -1) {
5167-
//found the smallest buffer that fits our needs
5168-
vk_buffer b = ctx->buffer_pool[best_i];
5169-
ctx->buffer_pool[best_i].reset();
5170-
return b;
5171-
}
5172-
if(worst_i != -1) {
5173-
//no buffer that fits our needs, resize largest one to save memory
5174-
vk_buffer& b = ctx->buffer_pool[worst_i];
5175-
ggml_vk_destroy_buffer(b);
5176-
}
5177-
5178-
return ggml_vk_create_buffer_device(ctx->device, size);
5179-
}
5180-
5181-
static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
5182-
VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
5183-
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
5184-
vk_buffer& b = ctx->buffer_pool[i];
5185-
if (b == nullptr) {
5186-
b = buffer;
5187-
return;
5188-
}
5189-
}
5190-
std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
5191-
ggml_vk_destroy_buffer(buffer);
5192-
}
5193-
5194-
// Returns an available temporary buffer that may only be used temporarily, it will be reused
5195-
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
5196-
// Try to find existing temp buffer with enough capacity
5197-
for (auto& buffer : ctx->gc.temp_buffers) {
5198-
if (buffer->size >= size) {
5199-
return buffer;
5200-
}
5201-
}
5202-
5203-
VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
5204-
5205-
// Otherwise create new buffer
5206-
vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
5207-
ctx->gc.temp_buffers.push_back(buf);
5208-
5209-
return buf;
5210-
}
5211-
52125147
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
52135148
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
52145149
vk_buffer buf = ggml_vk_create_buffer(device, size,
@@ -11789,10 +11724,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1178911724
// Clean up after graph processing is done
1179011725
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
1179111726
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
11792-
for (auto& buffer : ctx->gc.temp_buffers) {
11793-
ggml_vk_pool_free(ctx, buffer);
11794-
}
11795-
ctx->gc.temp_buffers.clear();
1179611727
ctx->prealloc_y_last_pipeline_used = {};
1179711728

1179811729
ctx->unsynced_nodes_written.clear();
@@ -11835,10 +11766,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
1183511766
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
1183611767
ctx->prealloc_y_last_pipeline_used = nullptr;
1183711768

11838-
for (auto& buffer : ctx->buffer_pool) {
11839-
ggml_vk_destroy_buffer(buffer);
11840-
}
11841-
1184211769
ctx->prealloc_size_x = 0;
1184311770
ctx->prealloc_size_y = 0;
1184411771
ctx->prealloc_size_split_k = 0;

ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#version 450
22

33
#extension GL_EXT_control_flow_attributes : require
4+
#if USE_SUBGROUP_ADD
5+
#extension GL_KHR_shader_subgroup_arithmetic : enable
6+
#endif
47

58
#include "types.glsl"
69

@@ -84,35 +87,47 @@ void main() {
8487
}
8588

8689
barrier();
87-
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
88-
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
89-
const uint k = (tid % (w >> 1)) +
90-
(D_STATE * (tid / (w >> 1))) +
91-
j * D_STATE * (D_STATE / (w >> 1));
92-
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
93-
stateC[k] += stateC[k + (w >> 1)];
90+
[[unroll]]
91+
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
92+
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
93+
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
94+
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
95+
stateC[k] += stateC[k + w];
9496
}
9597
}
9698
barrier();
9799
}
98100

99-
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
101+
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
100102
const uint idx = (tid % SUBGROUP_SIZE) +
101103
D_STATE * (tid / SUBGROUP_SIZE) +
102104
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
105+
const uint max_idx = SUBGROUP_SIZE - 1 +
106+
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
107+
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
103108

104-
uint lane = tid % SUBGROUP_SIZE;
105-
106-
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
107-
if (idx + offset < SPLIT_H * D_STATE) {
108-
stateC[idx] += stateC[idx + offset];
109+
if (idx < SPLIT_H * D_STATE ||
110+
max_idx < SPLIT_H * D_STATE) {
111+
float sc;
112+
#if USE_SUBGROUP_ADD
113+
sc = stateC[idx];
114+
sc = subgroupAdd(sc);
115+
#else
116+
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
117+
if (idx + offset < SPLIT_H * D_STATE) {
118+
stateC[idx] += stateC[idx + offset];
119+
}
120+
barrier();
109121
}
110-
barrier();
111-
}
122+
if (tid % SUBGROUP_SIZE == 0) {
123+
sc = stateC[idx];
124+
}
125+
#endif
112126

113-
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
114-
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
115-
d[y_base_idx + i * stride_y + k] = stateC[idx];
127+
if (tid % SUBGROUP_SIZE == 0) {
128+
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
129+
d[y_base_idx + i * stride_y + k] = sc;
130+
}
116131
}
117132
}
118133

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,8 @@ void process_shaders() {
916916
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
917917
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
918918

919-
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
919+
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
920+
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
920921

921922
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
922923

src/llama-model.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17965,6 +17965,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
1796517965
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
1796617966
cb(cur, "result_norm", -1);
1796717967

17968+
res->t_embd = cur;
17969+
1796817970
// lm_head
1796917971
cur = build_lora_mm(model.output, cur);
1797017972
cb(cur, "result_output", -1);

0 commit comments

Comments
 (0)