diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c81ff03fee810..bff7dea3a539b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -528,15 +528,15 @@ extern "C" { GGML_UNARY_OP_STEP, GGML_UNARY_OP_TANH, GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, GGML_UNARY_OP_SIGMOID, GGML_UNARY_OP_GELU, - GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_GELU_QUICK, GGML_UNARY_OP_SILU, GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, - GGML_UNARY_OP_RELU, + GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_COUNT, }; diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index cbf9783b744d1..9c67664af8587 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2697,14 +2697,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* } } - // GroupedMatmulV2 required tensor_list.size < 128 size_t GROUP_SIZE = 128; - std::vector> src0_tensor_vec_vec; - std::vector> src1_tensor_vec_vec; - std::vector> dst_tensor_vec_vec; - - // split and call GroupedMatmulV2 + // GroupedMatmulV2 required tensor_list.size < 128 for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { + // split and call GroupedMatmulV2 size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); @@ -2722,6 +2718,133 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* return; } +/** + * @brief Performs expert-specific matrix multiplication (MoE) with + * quantized precision using the CANN backend. + * + * This function executes a matrix multiplication operation tailored for + * Mixture of Experts (MoE) models, where the input tensor is multiplied + * with expert-specific quantized weight matrices. It leverages the CANN + * backend to perform efficient low-precision computations and stores the + * quantized result in the destination tensor `dst`. + * + * Quantization techniques reduce memory footprint and improve performance + * by using lower-bit representations (e.g., int8) instead of floating-point. + * This function is designed to work with such formats and may incorporate + * optimizations like identity-based fast paths or routing masks for sparse + * expert selection. + * + * @param ctx The context for executing CANN backend operations. + * @param dst The destination tensor where the quantized MoE multiplication result + * will be stored. + * + * @note This function assumes quantized data types and is designed for + * MoE architectures with potential sparse expert routing. 
+ */ +static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + // TODO: Use aclnnGroupedMatMul + //dst [M, K, N, 1] + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * ids = dst->src[2]; //ids [K, N] + + GGML_TENSOR_BINARY_OP_LOCALS + + // copy index from npu to cpu + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K + + std::vector ids_host(ggml_nbytes(ids)); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), + ACL_MEMCPY_DEVICE_TO_HOST); + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + + char * src0_original = (char *) src0->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + + ggml_tensor src0_row = *src0; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + const enum ggml_type type = dst->src[0]->type; + float weight_elem_size; + if (type == GGML_TYPE_Q4_0) { + weight_elem_size = float(sizeof(uint8_t)) / 2; + } else if (type == GGML_TYPE_Q8_0) { + weight_elem_size = float(sizeof(uint8_t)); + } else { + GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); + } + + // src0_row [D, M, 1, 1] weight without permute + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[0] = weight_elem_size; + src0_row.nb[1] = weight_elem_size * ne00; + src0_row.nb[2] = weight_elem_size * ne00; + src0_row.nb[3] = weight_elem_size * ne00; + size_t weight_stride = ne00 * ne01 * weight_elem_size; + size_t weight_size = weight_stride * ne02 * ne03; + + // scale [D, M, 1, 1] -> scale && permute + size_t scale_elem_size = sizeof(uint16_t); + size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + + // src1_row [D, 1, 1, 1] -> input + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + // dst_row [M, 1, 1, 1] -> out + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + //create weight for one row + ggml_cann_pool_alloc weight_allocator(ctx.pool()); + void* weight_buffer = weight_allocator.alloc(nb02); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 
0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*weight_stride; + void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + // mem cpy + ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride, + ACL_MEMCPY_DEVICE_TO_DEVICE); + void* scale_buffer = (char*)weight_buffer + weight_stride; + ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride, + ACL_MEMCPY_DEVICE_TO_DEVICE); + + src0_row.data = weight_buffer; + src1_row.data = src1_tmp_ptr; + dst_row.data = dst_tmp_ptr; + dst_row.src[0] = &src0_row; + dst_row.src[1] = &src1_row; + + ggml_cann_mul_mat(ctx, &dst_row); + } + } + return; +} + void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const enum ggml_type type = dst->src[0]->type; switch (type) { @@ -2729,6 +2852,10 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { case GGML_TYPE_F16: ggml_cann_mul_mat_id_fp(ctx, dst); break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + ggml_cann_mul_mat_id_quant(ctx, dst); + break; default: GGML_ABORT("Unsupported type for mul_mat_id"); break; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 0cb7bbf17cca5..605b6a73c3a13 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2035,6 +2035,15 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_TYPE_F16: case GGML_TYPE_F32: return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: +#ifdef ASCEND_310P + // Q4 && Q8 per group is not suppor on 310p device + return false; +#endif + // only support contiguous for quantized types. + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]); default: return false; } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c160a9984075f..cf6dd9d44b8a3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2804,23 +2804,29 @@ static vk_device ggml_vk_get_device(size_t idx) { pipeline_robustness = true; } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT")) { device->coopmat_support = true; device->coopmat_m = 0; device->coopmat_n = 0; device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { device->integer_dot_product = true; #endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; +#endif } } @@ -4670,6 +4676,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } } + if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. 
But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_contig_cpy_f16_f16; + } + } + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); } @@ -6731,7 +6750,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_UNARY: case GGML_OP_CONV_2D_DW: { - const uint32_t ne = ggml_nelements(dst); + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } if (ne > 262144) { elements = { 512, 512, CEIL_DIV(ne, 262144) }; } else if (ne > 512) { @@ -7281,8 +7309,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { - (uint32_t)ggml_nelements(src0), + ne, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, @@ -9264,8 +9303,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_ try { ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); } catch (vk::SystemError& e) { - std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl; - std::cerr << "ggml_vulkan: " << e.what() << std::endl; + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); } @@ -9867,6 +9905,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. 
+ if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } return false; } break; case GGML_OP_REPEAT: diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 4f84e56b3d3aa..1499eb08a5dd9 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -2,6 +2,22 @@ #include "ggml.h" +void llama_hparams::set_swa_pattern(uint32_t n_pattern) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } +} + +bool llama_hparams::is_swa_any() const { + for (uint32_t il = 0; il < n_layer; ++il) { + if (swa_layers[il]) { + return true; + } + } + + return false; +} + uint32_t llama_hparams::n_head(uint32_t il) const { if (il < n_layer) { return n_head_arr[il]; @@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const { bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { - return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1)); + return swa_layers[il]; } GGML_ABORT("fatal error"); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 5222eedcfb099..2d72eab180ad0 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -102,20 +102,12 @@ struct llama_hparams { // Sliding Window Attention (SWA) llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - - uint32_t n_swa = 0; // the size of the sliding window (0 - no SWA) - uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA) - // by default n == 1, all layers are dense - // note that if n_swa_pattern == 0, all layers are SWA - // example: n_swa_pattern = 3 - // il == 0: swa - // il == 1: swa - // il == 2: dense - // il == 3: swa - // il == 4: swa - // il == 5: dense - // il == 6: swa - // etc ... + // the size of the sliding window (0 - no SWA) + uint32_t n_swa = 0; + // if swa_layers[il] == true, then layer il is SWA + // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // by default, all layers are dense + std::array swa_layers; // for State Space Models uint32_t ssm_d_conv = 0; @@ -153,6 +145,23 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // note that if n_pattern == 0, all layers are SWA + // if n_pattern == 1, all layers are dense + // example: n_pattern = 3 + // il == 0: swa + // il == 1: swa + // il == 2: dense + // il == 3: swa + // il == 4: swa + // il == 5: dense + // il == 6: swa + // etc ... 
+ void set_swa_pattern(uint32_t n_pattern); + + // return true if one of the layers is SWA + bool is_swa_any() const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3735e3c16f0d8..81b052e1b1a47 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -463,11 +463,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { GGML_ASSERT(hparams.n_expert_used == 0); } - // zero-out the array hparams std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); + + std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -574,7 +577,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick - hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full switch (hparams.n_expert) { case 16: type = LLM_TYPE_17B_16E; break; @@ -863,7 +866,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_swa = 0; - hparams.n_swa_pattern = 1; + hparams.set_swa_pattern(1); } } break; case LLM_ARCH_PHIMOE: @@ -935,7 +938,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; // default value of gemma 2 - hparams.n_swa_pattern = 2; + hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -953,7 +956,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GEMMA3: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa_pattern = 6; + hparams.set_swa_pattern(6); hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; @@ -1038,7 +1041,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_COHERE2: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa_pattern = 4; + hparams.set_swa_pattern(4); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -4320,7 +4323,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); @@ -13216,7 +13219,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, 
cparams.n_ctx); if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.n_swa_pattern != 1); + GGML_ASSERT(hparams.is_swa_any()); res = new llama_kv_cache_unified_iswa( *this, @@ -13230,7 +13233,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_batch, padding); } else { - GGML_ASSERT(hparams.n_swa_pattern == 1); + GGML_ASSERT(!hparams.is_swa_any()); res = new llama_kv_cache_unified( *this, diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 5254b2821e504..b79094c0a48b6 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -12,17 +12,7 @@ size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks) { size_t n_tokens = 0; for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { auto chunk = mtmd_input_chunks_get(chunks, i); - auto chunk_type = mtmd_input_chunk_get_type(chunk); - if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens_text; - mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); - n_tokens += n_tokens_text; - } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); - n_tokens += mtmd_image_tokens_get_n_tokens(tokens_image); - } else { - GGML_ASSERT(false && "chunk type not supported"); - } + n_tokens += mtmd_input_chunk_get_n_tokens(chunk); } return n_tokens; } @@ -31,17 +21,7 @@ llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks) { llama_pos n_pos = 0; for (size_t i = 0; i < mtmd_input_chunks_size(chunks); i++) { auto chunk = mtmd_input_chunks_get(chunks, i); - auto chunk_type = mtmd_input_chunk_get_type(chunk); - if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens_text; - mtmd_input_chunk_get_tokens_text(chunk, &n_tokens_text); - n_pos += n_tokens_text; - } else if (chunk_type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - auto tokens_image = mtmd_input_chunk_get_tokens_image(chunk); - n_pos += mtmd_image_tokens_get_n_pos(tokens_image); - } else { - GGML_ASSERT(false && "chunk type not supported"); - } + n_pos += mtmd_input_chunk_get_n_pos(chunk); } return n_pos; } diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 344fe0b07dcf7..d3f3cf3a061de 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -751,6 +751,10 @@ const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap) { return bitmap->data.data(); } +size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap) { + return bitmap->data.size(); +} + bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap) { return bitmap->is_audio; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 0f4f9c62b7e97..2c722b012ea05 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -119,11 +119,12 @@ MTMD_API bool mtmd_support_audio(mtmd_context * ctx); // the data is in float format (PCM F32) MTMD_API mtmd_bitmap * mtmd_bitmap_init (uint32_t nx, uint32_t ny, const unsigned char * data); MTMD_API mtmd_bitmap * mtmd_bitmap_init_from_audio(size_t n_samples, const float * data); -MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap); -MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap); -MTMD_API const unsigned char * mtmd_bitmap_get_data(const mtmd_bitmap * bitmap); -MTMD_API bool mtmd_bitmap_is_audio(const mtmd_bitmap * bitmap); -MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap); +MTMD_API uint32_t mtmd_bitmap_get_nx (const mtmd_bitmap * bitmap); +MTMD_API uint32_t mtmd_bitmap_get_ny (const mtmd_bitmap * bitmap); +MTMD_API const unsigned char * mtmd_bitmap_get_data (const 
mtmd_bitmap * bitmap); +MTMD_API size_t mtmd_bitmap_get_n_bytes(const mtmd_bitmap * bitmap); +MTMD_API bool mtmd_bitmap_is_audio (const mtmd_bitmap * bitmap); +MTMD_API void mtmd_bitmap_free (mtmd_bitmap * bitmap); // bitmap ID is optional, but useful for KV cache tracking // these getters/setters are dedicated functions, so you can for example calculate the hash of the image based on mtmd_bitmap_get_data() MTMD_API const char * mtmd_bitmap_get_id(const mtmd_bitmap * bitmap); @@ -322,6 +323,7 @@ struct bitmap { uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); } uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); } const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); } + size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); } std::string id() { return mtmd_bitmap_get_id(ptr.get()); } void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); } }; diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 02fb00339ec8d..3f1d3f31dcbf9 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 1a08e30d28751..01afeafa0ff57 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1891,6 +1891,7 @@ struct server_context { float slot_prompt_similarity = 0.0f; common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; ~server_context() { mtmd_free(mctx); @@ -2086,6 +2087,15 @@ struct server_context { } metrics.init(); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, + }; } server_slot * get_slot_by_id(int id) { @@ -4092,7 +4102,10 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model.path }, - { "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future + { "modalities", json{ + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, @@ -4183,10 +4196,10 @@ int main(int argc, char ** argv) { for (auto & file : files) { mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size())); if (!bmp.ptr) { - throw std::runtime_error("Failed to load image"); + throw std::runtime_error("Failed to load image or audio file"); } // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3); + std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); bmp.set_id(hash.c_str()); bitmaps.entries.push_back(std::move(bmp)); } @@ -4418,7 +4431,7 @@ int main(int argc, char ** argv) { OAICOMPAT_TYPE_NONE); // infill is not OAI compatible }; - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { LOG_DBG("request: %s\n", req.body.c_str()); if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. 
Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); @@ -4427,13 +4440,9 @@ int main(int argc, char ** argv) { auto body = json::parse(req.body); std::vector files; - json data = oaicompat_completion_params_parse( + json data = oaicompat_chat_params_parse( body, - params.use_jinja, - params.prefill_assistant, - params.reasoning_format, - ctx_server.chat_templates.get(), - ctx_server.mctx, + ctx_server.oai_parser_opt, files); handle_completions_impl( @@ -4446,16 +4455,12 @@ int main(int argc, char ** argv) { }; // same with handle_chat_completions, but without inference part - const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { auto body = json::parse(req.body); std::vector files; // dummy, unused - json data = oaicompat_completion_params_parse( + json data = oaicompat_chat_params_parse( body, - params.use_jinja, - params.prefill_assistant, - params.reasoning_format, - ctx_server.chat_templates.get(), - ctx_server.mctx, + ctx_server.oai_parser_opt, files); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); }; diff --git a/tools/server/tests/unit/test_vision_api.py b/tools/server/tests/unit/test_vision_api.py index 7cc4096f19e0c..fc63caa134293 100644 --- a/tools/server/tests/unit/test_vision_api.py +++ b/tools/server/tests/unit/test_vision_api.py @@ -30,6 +30,7 @@ def create_server(): ("What is this:\n", "malformed", False, None), ("What is this:\n", "https://google.com/404", False, None), # non-existent image ("What is this:\n", "https://ggml.ai", False, None), # non-image data + # TODO @ngxson : test with multiple images, no images and with audio ] ) def test_vision_chat_completion(prompt, image_url, success, re_content): diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 9c41f9db5ff68..bb27b366ea2d6 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -536,6 +536,7 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons // OAI utils // +// used by /completions endpoint static json oaicompat_completion_params_parse(const json & body) { json llama_params; @@ -580,13 +581,19 @@ static json oaicompat_completion_params_parse(const json & body) { return llama_params; } -static json oaicompat_completion_params_parse( +struct oaicompat_parser_options { + bool use_jinja; + bool prefill_assistant; + common_reasoning_format reasoning_format; + common_chat_templates * tmpls; + bool allow_image; + bool allow_audio; +}; + +// used by /chat/completions endpoint +static json oaicompat_chat_params_parse( const json & body, /* openai api json semantics */ - bool use_jinja, - bool prefill_assistant, - common_reasoning_format reasoning_format, - const struct common_chat_templates * tmpls, - bool allow_non_text, + const oaicompat_parser_options & opt, std::vector & out_files) { json llama_params; @@ -598,11 +605,11 @@ static json oaicompat_completion_params_parse( if (stream) { throw std::runtime_error("Cannot use tools with stream"); } - if (!use_jinja) { + if (!opt.use_jinja) { throw std::runtime_error("tools param requires --jinja flag"); } } - if (!use_jinja) { + if (!opt.use_jinja) { if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { throw std::runtime_error("Unsupported param: tool_choice"); } @@ -667,12 +674,12 @@ static json oaicompat_completion_params_parse( for (auto & p : content) { std::string type = json_value(p, 
"type", std::string()); - json image_url = json_value(p, "image_url", json::object()); if (type == "image_url") { - if (!allow_non_text) { - throw std::runtime_error("image input is not supported by this server"); + if (!opt.allow_image) { + throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } + json image_url = json_value(p, "image_url", json::object()); std::string url = json_value(image_url, "url", std::string()); if (string_starts_with(url, "http")) { // download remote image @@ -712,6 +719,29 @@ static json oaicompat_completion_params_parse( p["type"] = "text"; p["text"] = mtmd_default_marker(); p.erase("image_url"); + + } else if (type == "input_audio") { + if (!opt.allow_audio) { + throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); + } + + json input_audio = json_value(p, "input_audio", json::object()); + std::string data = json_value(input_audio, "data", std::string()); + std::string format = json_value(input_audio, "format", std::string()); + // while we also support flac, we don't allow it here so we matches the OAI spec + if (format != "wav" && format != "mp3") { + throw std::runtime_error("input_audio.format must be either 'wav' or 'mp3'"); + } + auto decoded_data = base64_decode(data); // expected to be base64 encoded + out_files.push_back(decoded_data); + + // replace this chunk with a marker + p["type"] = "text"; + p["text"] = mtmd_default_marker(); + p.erase("input_audio"); + + } else if (type != "text") { + throw std::runtime_error("unsupported content[].type"); } } } @@ -723,9 +753,9 @@ static json oaicompat_completion_params_parse( inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); inputs.grammar = grammar; inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - inputs.use_jinja = use_jinja; + inputs.use_jinja = opt.use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; + inputs.extract_reasoning = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE; inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { throw std::runtime_error("Cannot use custom grammar constraints with tools."); @@ -733,7 +763,7 @@ static json oaicompat_completion_params_parse( // if the assistant message appears at the end of list, we do not add end-of-turn token // for ex. 
this can be useful to modify the reasoning process in reasoning models - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && prefill_assistant; + bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; common_chat_msg last_message; if (prefill_assistant_message) { last_message = inputs.messages.back(); @@ -749,7 +779,7 @@ static json oaicompat_completion_params_parse( } // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(tmpls, inputs); + auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); /* Append assistant prefilled message */ if (prefill_assistant_message) { @@ -1040,7 +1070,7 @@ struct server_tokens { private: // disallow accessing these members directly, risking out-of-sync // map a **start** position in tokens to the image chunk - std::unordered_map map_pos_to_image; + std::unordered_map map_pos_to_media; // list of tokens // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token @@ -1051,7 +1081,7 @@ struct server_tokens { // for ex. with input of 5 text tokens and 2 images: // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] // pos 0 1 2 3 4 5 6 7 8 9 - // map_pos_to_image will contain: {5, img0}, {8, img1} + // map_pos_to_media will contain: {5, img0}, {8, img1} public: server_tokens() = default; @@ -1090,15 +1120,15 @@ struct server_tokens { } oss << "\n"; oss << "image pos: "; - for (const auto & it : map_pos_to_image) { + for (const auto & it : map_pos_to_media) { oss << it.first << ", "; } return oss.str(); } const mtmd::input_chunk_ptr & find_chunk(llama_pos pos) const { - auto it = map_pos_to_image.find(pos); - if (it != map_pos_to_image.end()) { + auto it = map_pos_to_media.find(pos); + if (it != map_pos_to_media.end()) { return it->second; } else { throw std::runtime_error("Chunk not found"); @@ -1115,16 +1145,15 @@ struct server_tokens { // will create a copy of the chunk if it contains non-text data void push_back(const mtmd_input_chunk * chunk) { auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { + if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { GGML_ASSERT(has_mtmd); - auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk); - const int n_pos = mtmd_image_tokens_get_n_pos(img_tokens); + const int n_pos = mtmd_input_chunk_get_n_pos(chunk); llama_pos start_pos = tokens.size(); for (int i = 0; i < n_pos; ++i) { tokens.emplace_back(LLAMA_TOKEN_NULL); } mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_pos_to_image[start_pos] = std::move(new_chunk); + map_pos_to_media[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { size_t n_tokens; auto text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); @@ -1169,6 +1198,9 @@ struct server_tokens { void keep_first(size_t n) { GGML_ASSERT(n <= tokens.size()); if (has_mtmd) { + if (n == tokens.size()) { + return; // nothing to do + } // we throw an error if we try to remove a token in the middle of an image // for ex. 
with input of 5 text tokens and 2 images: // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] @@ -1183,10 +1215,10 @@ struct server_tokens { } } // remove all image chunks that are not used anymore - for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) { + for (auto it = map_pos_to_media.begin(); it != map_pos_to_media.end(); ) { llama_pos pos = it->first; if (pos >= (llama_pos)n) { - it = map_pos_to_image.erase(it); + it = map_pos_to_media.erase(it); } else { ++it; } @@ -1217,14 +1249,12 @@ struct server_tokens { const auto & a_chunk = find_chunk(i); const auto & b_chunk = b.find_chunk(i); GGML_ASSERT(a_chunk && b_chunk); - const auto * a_img = mtmd_input_chunk_get_tokens_image(a_chunk.get()); - const auto * b_img = mtmd_input_chunk_get_tokens_image(b_chunk.get()); - std::string ai_id = mtmd_image_tokens_get_id(a_img); - std::string bi_id = mtmd_image_tokens_get_id(b_img); - size_t a_pos = mtmd_image_tokens_get_n_pos(a_img); - size_t b_pos = mtmd_image_tokens_get_n_pos(b_img); + std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get()); + std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get()); + size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get()); + size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get()); if (ai_id == bi_id && a_pos == b_pos) { - GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen + GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen i += a_pos - 1; // will be +1 by the for loop continue; } else { @@ -1250,8 +1280,7 @@ struct server_tokens { if (t == LLAMA_TOKEN_NULL) { try { const auto & chunk = find_chunk(i); - const auto * img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get()); - size_t n_pos = mtmd_image_tokens_get_n_pos(img_tokens); + size_t n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); i += n_pos - 1; // will be +1 by the for loop } catch (const std::exception & e) { return false; @@ -1270,22 +1299,21 @@ struct server_tokens { llama_pos n_past, int32_t seq_id, llama_pos & n_pos_out) { - auto it = map_pos_to_image.find(n_past); - if (it == map_pos_to_image.end()) { - throw std::runtime_error("Chunk not found"); - } - SRV_INF("%s\n", "processing image..."); + auto & chunk = find_chunk(n_past); + const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE + ? 
"image" : "audio"; + SRV_INF("processing %s...\n", name); int32_t n_batch = llama_n_batch(ctx); int64_t t0 = ggml_time_ms(); llama_pos new_n_past = n_past; int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - it->second.get(), // chunk + chunk.get(), n_past, seq_id, n_batch, true, // logits last &new_n_past); - SRV_INF("image processed in %" PRId64 " ms\n", ggml_time_ms() - t0); + SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); if (result != 0) { LOG_ERR("mtmd_helper_eval failed with status %d", result); n_pos_out = n_past; diff --git a/tools/server/webui/src/components/ChatInputExtraContextItem.tsx b/tools/server/webui/src/components/ChatInputExtraContextItem.tsx index 4f28f887482a6..2d4179ea4703e 100644 --- a/tools/server/webui/src/components/ChatInputExtraContextItem.tsx +++ b/tools/server/webui/src/components/ChatInputExtraContextItem.tsx @@ -1,4 +1,8 @@ -import { DocumentTextIcon, XMarkIcon } from '@heroicons/react/24/outline'; +import { + DocumentTextIcon, + SpeakerWaveIcon, + XMarkIcon, +} from '@heroicons/react/24/outline'; import { MessageExtra } from '../utils/types'; import { useState } from 'react'; import { classNames } from '../utils/misc'; @@ -66,7 +70,11 @@ export default function ChatInputExtraContextItem({ className="w-14 h-14 flex items-center justify-center" aria-description="Document icon" > - + {item.type === 'audioFile' ? ( + + ) : ( + + )}
@@ -98,6 +106,19 @@ export default function ChatInputExtraContextItem({ src={showingItem.base64Url} alt={`Preview image for ${showingItem.name}`} /> + ) : showingItem.type === 'audioFile' ? ( + ) : (
diff --git a/tools/server/webui/src/components/ChatScreen.tsx b/tools/server/webui/src/components/ChatScreen.tsx
index 09c601ef2366a..c1a6691445507 100644
--- a/tools/server/webui/src/components/ChatScreen.tsx
+++ b/tools/server/webui/src/components/ChatScreen.tsx
@@ -278,6 +278,13 @@ export default function ChatScreen() {
 
 function ServerInfo() {
   const { serverProps } = useAppContext();
+  const modalities = [];
+  if (serverProps?.modalities?.audio) {
+    modalities.push('audio');
+  }
+  if (serverProps?.modalities?.vision) {
+    modalities.push('vision');
+  }
   return (
     
Build: {serverProps?.build_info}
+          {modalities.length > 0 ? (
+            <>
+              Supported modalities: {modalities.join(', ')}
+            </>
+          ) : (
+            ''
+          )}

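(Not part of the patch — a rough TypeScript sketch of how a client could consume the extended /props response that the ServerInfo component above reads. The interface and helper names are made up for illustration, and a browser or Node 18+ fetch is assumed.)

    // Rough sketch (assumption, not from the PR): gate audio/image upload in a
    // client on the modalities advertised by GET /props.
    interface ServerPropsSketch {
      build_info?: string;
      modalities?: { vision: boolean; audio: boolean };
    }

    async function fetchSupportedModalities(baseUrl: string): Promise<string[]> {
      const props = (await (await fetch(`${baseUrl}/props`)).json()) as ServerPropsSketch;
      const out: string[] = [];
      if (props.modalities?.audio) out.push('audio');   // true when an audio-capable mmproj is loaded
      if (props.modalities?.vision) out.push('vision'); // true when a vision-capable mmproj is loaded
      return out;
    }

    // Usage: const mods = await fetchSupportedModalities('http://localhost:8080');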
diff --git a/tools/server/webui/src/components/useChatExtraContext.tsx b/tools/server/webui/src/components/useChatExtraContext.tsx index b9794405a5da5..42765524067e2 100644 --- a/tools/server/webui/src/components/useChatExtraContext.tsx +++ b/tools/server/webui/src/components/useChatExtraContext.tsx @@ -11,6 +11,7 @@ pdfjs.GlobalWorkerOptions.workerSrc = pdfjsWorkerSrc; // This file handles uploading extra context items (a.k.a files) // It allows processing these kinds of files: // - image files (converted to base64) +// - audio files (converted to base64) // - text files (including code files) // - pdf (converted to text) @@ -41,96 +42,73 @@ export function useChatExtraContext(): ChatExtraContextApi { const isSupportVision = serverProps?.modalities?.vision; - const onFileAdded = (files: File[]) => { - for (const file of files) { - const mimeType = file.type; - console.debug({ mimeType, file }); - if (file.size > 10 * 1024 * 1024) { - toast.error('File is too large. Maximum size is 10MB.'); - break; - } - - if (mimeType.startsWith('image/')) { - if (!isSupportVision) { - toast.error('Multimodal is not supported by this server or model.'); + const onFileAdded = async (files: File[]) => { + try { + for (const file of files) { + const mimeType = file.type; + if (file.size > 10 * 1024 * 1024) { + toast.error('File is too large. Maximum size is 10MB.'); break; } - const reader = new FileReader(); - reader.onload = async (event) => { - if (event.target?.result) { - let base64Url = event.target.result as string; - - if (mimeType === 'image/svg+xml') { - // Convert SVG to PNG - base64Url = await svgBase64UrlToPngDataURL(base64Url); - } - addItems([ - { - type: 'imageFile', - name: file.name, - base64Url, - }, - ]); + if (mimeType.startsWith('image/')) { + if (!isSupportVision) { + toast.error('Multimodal is not supported by this server or model.'); + break; } - }; - reader.readAsDataURL(file); - } else if ( - mimeType.startsWith('video/') || - mimeType.startsWith('audio/') - ) { - toast.error('Video and audio files are not supported yet.'); - break; - } else if (mimeType.startsWith('application/pdf')) { - if (config.pdfAsImage && !isSupportVision) { - toast( - 'Multimodal is not supported, PDF will be converted to text instead of image.' - ); + + let base64Url = await getFileAsBase64(file); + if (mimeType === 'image/svg+xml') { + // Convert SVG to PNG + base64Url = await svgBase64UrlToPngDataURL(base64Url); + } + addItems([ + { + type: 'imageFile', + name: file.name, + base64Url, + }, + ]); + } else if (mimeType.startsWith('video/')) { + toast.error('Video files are not supported yet.'); break; - } + } else if (mimeType.startsWith('audio/')) { + if (!/mpeg|wav/.test(mimeType)) { + toast.error('Only mp3 and wav audio files are supported.'); + break; + } - const promise = - config.pdfAsImage && isSupportVision - ? convertPDFToImage(file).then((base64Urls) => { - addItems( - base64Urls.map((base64Url) => ({ - type: 'imageFile', - name: file.name, - base64Url, - })) - ); - }) - : convertPDFToText(file).then((content) => { - if (isSupportVision) { - toast.success( - 'PDF file converted to text. You can also convert it to image, see in Settings.' 
- ); - } - addItems([ - { - type: 'textFile', - name: file.name, - content, - }, - ]); - }); - - promise.catch((error) => { - console.error(error); - toast.error('Failed to parse PDF file.'); - }); - break; - } else { - // Because there can be many text file types (like code file), we will not check the mime type - // and will just check if the file is not binary. - const reader = new FileReader(); - reader.onload = (event) => { - if (event.target?.result) { - const content = event.target.result as string; - if (!isLikelyNotBinary(content)) { - toast.error('File is binary. Please upload a text file.'); - return; - } + // plain base64, not a data URL + const base64Data = await getFileAsBase64(file, false); + addItems([ + { + type: 'audioFile', + name: file.name, + mimeType, + base64Data, + }, + ]); + } else if (mimeType.startsWith('application/pdf')) { + if (config.pdfAsImage && !isSupportVision) { + toast( + 'Multimodal is not supported, PDF will be converted to text instead of image.' + ); + break; + } + + if (config.pdfAsImage && isSupportVision) { + // Convert PDF to images + const base64Urls = await convertPDFToImage(file); + addItems( + base64Urls.map((base64Url) => ({ + type: 'imageFile', + name: file.name, + base64Url, + })) + ); + } else { + // Convert PDF to text + const content = await convertPDFToText(file); addItems([ { type: 'textFile', @@ -138,10 +116,40 @@ export function useChatExtraContext(): ChatExtraContextApi { content, }, ]); + if (isSupportVision) { + toast.success( + 'PDF file converted to text. You can also convert it to image, see in Settings.' + ); + } } - }; - reader.readAsText(file); + break; + } else { + // Because there can be many text file types (like code file), we will not check the mime type + // and will just check if the file is not binary. + const reader = new FileReader(); + reader.onload = (event) => { + if (event.target?.result) { + const content = event.target.result as string; + if (!isLikelyNotBinary(content)) { + toast.error('File is binary. Please upload a text file.'); + return; + } + addItems([ + { + type: 'textFile', + name: file.name, + content, + }, + ]); + } + }; + reader.readAsText(file); + } } + } catch (error) { + const message = error instanceof Error ? 
error.message : String(error); + const errorMessage = `Error processing file: ${message}`; + toast.error(errorMessage); } }; @@ -154,6 +162,25 @@ export function useChatExtraContext(): ChatExtraContextApi { }; } +async function getFileAsBase64(file: File, outputUrl = true): Promise { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = (event) => { + if (event.target?.result) { + let result = event.target.result as string; + if (!outputUrl) { + // remove base64 url prefix and correct characters + result = result.substring(result.indexOf(',') + 1); + } + resolve(result); + } else { + reject(new Error('Failed to read file.')); + } + }; + reader.readAsDataURL(file); + }); +} + async function getFileAsBuffer(file: File): Promise { return new Promise((resolve, reject) => { const reader = new FileReader(); diff --git a/tools/server/webui/src/utils/misc.ts b/tools/server/webui/src/utils/misc.ts index ba760e83bb282..d60a68cd2431b 100644 --- a/tools/server/webui/src/utils/misc.ts +++ b/tools/server/webui/src/utils/misc.ts @@ -89,6 +89,14 @@ export function normalizeMsgsForAPI(messages: Readonly) { type: 'image_url', image_url: { url: extra.base64Url }, }); + } else if (extra.type === 'audioFile') { + contentArr.push({ + type: 'input_audio', + input_audio: { + data: extra.base64Data, + format: /wav/.test(extra.mimeType) ? 'wav' : 'mp3', + }, + }); } else { throw new Error('Unknown extra type'); } diff --git a/tools/server/webui/src/utils/types.ts b/tools/server/webui/src/utils/types.ts index ba673dd9432da..ea7d641dc748b 100644 --- a/tools/server/webui/src/utils/types.ts +++ b/tools/server/webui/src/utils/types.ts @@ -51,6 +51,7 @@ export interface Message { export type MessageExtra = | MessageExtraTextFile | MessageExtraImageFile + | MessageExtraAudioFile | MessageExtraContext; export interface MessageExtraTextFile { @@ -65,6 +66,13 @@ export interface MessageExtraImageFile { base64Url: string; } +export interface MessageExtraAudioFile { + type: 'audioFile'; + name: string; + base64Data: string; + mimeType: string; +} + export interface MessageExtraContext { type: 'context'; name: string; @@ -79,6 +87,10 @@ export type APIMessageContentPart = | { type: 'image_url'; image_url: { url: string }; + } + | { + type: 'input_audio'; + input_audio: { data: string; format: 'wav' | 'mp3' }; }; export type APIMessage = { @@ -120,6 +132,7 @@ export interface LlamaCppServerProps { n_ctx: number; modalities?: { vision: boolean; + audio: boolean; }; // TODO: support params }
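(Also not part of the patch — a minimal TypeScript sketch of a request using the new OpenAI-compatible `input_audio` content part handled by `oaicompat_chat_params_parse` above. It assumes llama-server was started with an audio-capable mmproj on the default http://localhost:8080, Node 18+ for fetch and top-level await, and `sample.wav` is a placeholder file name. Per the added check, only 'wav' and 'mp3' formats are accepted.)

    // Minimal sketch (assumption, not from the PR): send a wav file through the new
    // OpenAI-compatible "input_audio" content part.
    import { readFileSync } from 'node:fs';

    const audioBase64 = readFileSync('sample.wav').toString('base64'); // placeholder file

    const res = await fetch('http://localhost:8080/v1/chat/completions', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        messages: [{
          role: 'user',
          content: [
            { type: 'text', text: 'What is said in this recording?' },
            // format must be 'wav' or 'mp3'; anything else is rejected by the parser above
            { type: 'input_audio', input_audio: { data: audioBase64, format: 'wav' } },
          ],
        }],
      }),
    });
    console.log((await res.json()).choices[0].message.content);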