Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
74a989b
feat(deepseek-ocr): deepseek-ocr support
chenghuaWang Oct 22, 2025
8eebf54
feat(deepseek_ocr): implement conversation management and preprocessi…
chenghuaWang Oct 23, 2025
60f6f92
feat(cpu): add interpolate and pad operations with full interpolation…
chenghuaWang Oct 23, 2025
bf82e6b
feat(image): add dynamic image preprocessing and cropping support
chenghuaWang Oct 23, 2025
6770f87
feat(interpolate): add antialias support and remove keep_aspect_ratio
chenghuaWang Oct 23, 2025
1515cdb
feat(deepseek_ocr): add mlp projector linear impl type configuration
chenghuaWang Oct 23, 2025
e858d25
feat(deepseek_ocr): add message formatting and model inference support
chenghuaWang Oct 23, 2025
4ca7a07
feat(ext): add tokenizers-cpp and opencv-mobile as optional extensions
chenghuaWang Oct 23, 2025
0eabe59
feat(ocr): add llvm-project submodule and update deepseek ocr model
chenghuaWang Oct 24, 2025
78a17fb
feat(cpu): add StackOp implementation and integrate into DeepSeek-OCR…
chenghuaWang Oct 24, 2025
593258e
feat(tokenizer): implement UTF-8 support for DeepSeek OCR tokenizer
chenghuaWang Oct 24, 2025
f76bce9
refactor(ext): replace tokenizers-cpp with tokenizers submodule
chenghuaWang Oct 24, 2025
73b74f2
feat(tokenizer): implement deepseek-ocr tokenizer with BPE and UTF-8 …
chenghuaWang Oct 25, 2025
a556ad2
feat(deepseek_ocr): improve tokenizer and add support for new special…
chenghuaWang Oct 25, 2025
7b2f501
feat(deepseek_ocr): refactor model loading and initialization with co…
chenghuaWang Oct 26, 2025
6732314
feat(cpu): implement MaskedScatterOp for CPU backend
chenghuaWang Oct 26, 2025
e849dbc
feat(deepseek_ocr): add DeepseekV2MLP, MoEGate, and DeepseekV2MoE mod…
chenghuaWang Oct 26, 2025
a80a973
fix(Tensor): cast rank to int32_t for negative index handling
chenghuaWang Oct 26, 2025
7e39401
docs(contribute): rename guidelines.md to guidelines.rst
chenghuaWang Oct 27, 2025
9b843b6
feat(tensor): support negative dim in repeat operation
chenghuaWang Oct 27, 2025
324cd50
feat(cpu): implement optimized softmax for last dimension cases
chenghuaWang Oct 27, 2025
00d787a
feat(cpu): add Tracy profiler option and update quantization config
chenghuaWang Oct 28, 2025
e861deb
feat(deepseek_ocr): update model paths and quantization config
chenghuaWang Oct 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mllm/backends/cpu/CPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
#include "mllm/backends/cpu/ops/ConcatOp.hpp"
#include "mllm/backends/cpu/ops/ContiguousOp.hpp"
#include "mllm/backends/cpu/ops/Conv1DOp.hpp"
#include "mllm/backends/cpu/ops/Conv2DOp.hpp"
#include "mllm/backends/cpu/ops/Conv3DOp.hpp"
#include "mllm/backends/cpu/ops/CopyOp.hpp"
#include "mllm/backends/cpu/ops/ElewiseOps.hpp"
#include "mllm/backends/cpu/ops/EmbeddingOp.hpp"
#include "mllm/backends/cpu/ops/FillOp.hpp"
#include "mllm/backends/cpu/ops/FlashAttention2Op.hpp"
#include "mllm/backends/cpu/ops/GELUOp.hpp"
#include "mllm/backends/cpu/ops/LayerNorm2DOp.hpp"
#include "mllm/backends/cpu/ops/RadixAttnOp.hpp"
#include "mllm/backends/cpu/ops/ReLUOp.hpp"
#include "mllm/backends/cpu/ops/GraphOps.hpp"
Expand Down Expand Up @@ -60,7 +62,8 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) {
CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory,
CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, CPUConv3DOpFactory,
CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory,
CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory>();
CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory,
CPUConv2DOpFactory, CPULayerNorm2DOpFactory>();
}

std::shared_ptr<CPUBackend> createCPUBackend() { return std::make_shared<CPUBackend>(); }
Expand Down
2 changes: 2 additions & 0 deletions mllm/backends/cpu/kernels/Kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "mllm/backends/cpu/kernels/arm/conv3d.hpp" // IWYU pragma: export
#include "mllm/backends/cpu/kernels/arm/linear/kai.hpp" // IWYU pragma: export
#include "mllm/backends/cpu/kernels/arm/relu.hpp" // IWYU pragma: export
#include "mllm/backends/cpu/kernels/arm/conv2d.hpp" // IWYU pragma: export
#include "mllm/backends/cpu/kernels/arm/layernorm2d.hpp" // IWYU pragma: export
#include "mllm/backends/cpu/kernels/arm/mllm_blas/mllm_blas_sgemm.hpp" // IWYU pragma: export
#else
#include "mllm/backends/cpu/kernels/common/gelu-inl.hpp" // IWYU pragma: export
Expand Down
100 changes: 100 additions & 0 deletions mllm/backends/cpu/kernels/arm/conv2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.
#include "mllm/backends/cpu/kernels/arm/conv2d.hpp"

#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)

#include <arm_neon.h>

namespace mllm::cpu::arm {

// Expands one fp32 image stored as [channels, height, width] into an im2col matrix.
//
// The output is written contiguously to `col_data` as channels * kernel_h * kernel_w rows, ordered by
// (channel, kernel_y, kernel_x). Each row holds output_h * output_w values in (out_y, out_x) order, so
// `col_data` must have room for channels * kernel_h * kernel_w * output_h * output_w floats.
// Taps that fall into the zero padding are written as 0.0f.
//
// output_h / output_w follow the usual convolution formula:
//   (size + 2 * pad - (dilation * (kernel - 1) + 1)) / stride + 1
//
// The NEON stores are guarded by __ARM_NEON, so the scalar loops alone produce the same result when
// NEON is unavailable.
void conv2d_fp32_im2col_input(const float* input_data, const int channels, const int height, const int width,
                              const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
                              const int stride_w, const int dilation_h, const int dilation_w, float* col_data) {
  const int output_h = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int output_w = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
  const int channel_size = height * width;

#if defined(__ARM_NEON)
  const float32x4_t vzero = vdupq_n_f32(0.0f);
#endif

  for (int channel = 0; channel < channels; ++channel) {
    for (int kernel_y = 0; kernel_y < kernel_h; ++kernel_y) {
      for (int kernel_x = 0; kernel_x < kernel_w; ++kernel_x) {
        // Input coordinate sampled by this kernel tap for output position (0, 0).
        const int input_start_y = -pad_h + kernel_y * dilation_h;
        const int input_start_x = -pad_w + kernel_x * dilation_w;

        for (int out_y = 0; out_y < output_h; ++out_y) {
          const int cur_input_y = input_start_y + out_y * stride_h;
          int out_x = 0;

          // The unsigned comparison rejects both negative and >= height in a single test.
          if (static_cast<unsigned>(cur_input_y) >= static_cast<unsigned>(height)) {
            // The whole output row reads from vertical padding: emit output_w zeros.
#if defined(__ARM_NEON)
            for (; out_x + 3 < output_w; out_x += 4) {
              vst1q_f32(col_data, vzero);
              col_data += 4;
            }
#endif
            for (; out_x < output_w; ++out_x) { *col_data++ = 0.0f; }
            continue;
          }

          const float* input_row = input_data + cur_input_y * width;

          // Bounds-checked fetch of the input element feeding output column `ox` (0 inside horizontal padding).
          const auto sample = [input_row, input_start_x, stride_w, width](const int ox) -> float {
            const int input_x = input_start_x + ox * stride_w;
            return (static_cast<unsigned>(input_x) < static_cast<unsigned>(width)) ? input_row[input_x] : 0.0f;
          };

#if defined(__ARM_NEON)
          for (; out_x + 3 < output_w; out_x += 4) {
            // Gather through a small stack buffer rather than brace-initialising the vector type,
            // which depends on compiler-specific behaviour.
            const float lanes[4] = {sample(out_x), sample(out_x + 1), sample(out_x + 2), sample(out_x + 3)};
            vst1q_f32(col_data, vld1q_f32(lanes));
            col_data += 4;
          }
#endif
          for (; out_x < output_w; ++out_x) { *col_data++ = sample(out_x); }
        }
      }
    }
    input_data += channel_size;
  }
}

// Repacks convolution weights into a row-major [out_channels x (in_channels * kernel_h * kernel_w)] matrix.
//
// Destination element (o, i, h, w) is stored at
//   o * (in_channels * kernel_h * kernel_w) + i * (kernel_h * kernel_w) + h * kernel_w + w.
// The source element is read from
//   h * (kernel_w * in_channels * out_channels) + w * (in_channels * out_channels) + i * out_channels + o,
// i.e. `src_weight` is indexed as [kernel_h, kernel_w, in_channels, out_channels].
//
// NOTE(review): the declaration comment in conv2d.hpp says the input layout is
// [Out_Channels, In_Channels, Kernel_H, Kernel_W], which does not match the indexing above.
// Confirm which layout callers actually provide.
//
// Both buffers must hold out_channels * in_channels * kernel_h * kernel_w floats.
void conv2d_fp32_im2col_weight(const float* src_weight, float* packed_weight, int out_channels, int in_channels, int kernel_h,
                               int kernel_w) {
  // Source strides (out_channels is the fastest-varying source dimension).
  const int src_stride_i = out_channels;
  const int src_stride_w = in_channels * out_channels;
  const int src_stride_h = kernel_w * src_stride_w;

  // Destination strides (kernel_w is the fastest-varying destination dimension).
  const int dst_stride_h = kernel_w;
  const int dst_stride_i = kernel_h * kernel_w;
  const int dst_stride_o = in_channels * dst_stride_i;

  for (int o = 0; o < out_channels; ++o) {
    for (int i = 0; i < in_channels; ++i) {
      for (int h = 0; h < kernel_h; ++h) {
        for (int w = 0; w < kernel_w; ++w) {
          const int src_idx = h * src_stride_h + w * src_stride_w + i * src_stride_i + o;
          const int dst_idx = o * dst_stride_o + i * dst_stride_i + h * dst_stride_h + w;
          packed_weight[dst_idx] = src_weight[src_idx];
        }
      }
    }
  }
}

} // namespace mllm::cpu::arm

#endif
34 changes: 34 additions & 0 deletions mllm/backends/cpu/kernels/arm/conv2d.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.

#pragma once

#include "mllm/utils/CPUArchHelper.hpp"

#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)

#include <arm_neon.h>

namespace mllm::cpu::arm {

//===----------------------------------------------------------------------===//
// Im2col.
//
// Convolution is computed in three steps:
//   1. Reformat the input with conv2d_fp32_im2col_input.
//   2. Reformat the weights with conv2d_fp32_im2col_weight.
//   3. Run gemm(weight, input) on the two reformatted buffers.
//===----------------------------------------------------------------------===//
// Expands one fp32 image of `channels` planes of height x width into `col_data`.
// Rows are ordered by (channel, kernel_y, kernel_x); each row holds output_h * output_w values,
// with 0.0f written for taps that fall into the padding.
void conv2d_fp32_im2col_input(const float* input_data, const int channels, const int height, const int width,
const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w, float* col_data);

// Input weight format is documented as [Out_Channels, In_Channels, Kernel_H, Kernel_W].
// NOTE(review): the implementation in conv2d.cpp indexes src_weight as
// [Kernel_H, Kernel_W, In_Channels, Out_Channels] — confirm which layout is intended.
//
// Output weight format is [M x K], with M = out_channels and K = in_channels * kernel_h * kernel_w.
//
// This kernel is not performance sensitive: the weights only need to be packed once.
void conv2d_fp32_im2col_weight(const float* src_weight, float* packed_weight, int out_channels, int in_channels, int kernel_h,
int kernel_w);

} // namespace mllm::cpu::arm
#endif
89 changes: 89 additions & 0 deletions mllm/backends/cpu/kernels/arm/layernorm2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.
#include "mllm/backends/cpu/kernels/arm/layernorm2d.hpp"
#include "mllm/backends/cpu/kernels/arm/conv2d.hpp"

#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)

#include <cmath>
#include <arm_neon.h>

namespace mllm::cpu::arm {
#if defined(__ARM_NEON)
// Horizontal sum of the four lanes of `v`.
// vaddvq_f32 is only provided on AArch64; 32-bit ARM builds fall back to two pairwise adds,
// which compute (v0 + v1) + (v2 + v3).
static inline float layernorm2d_hsum_f32(const float32x4_t v) {
#if defined(__aarch64__) || defined(_M_ARM64)
  return vaddvq_f32(v);
#else
  float32x2_t pair = vpadd_f32(vget_low_f32(v), vget_high_f32(v));
  pair = vpadd_f32(pair, pair);
  return vget_lane_f32(pair, 0);
#endif
}

// Gathers p[0], p[stride], p[2 * stride], p[3 * stride] into one vector.
// Goes through a stack buffer rather than brace-initialising the NEON vector type, which is compiler-specific.
static inline float32x4_t layernorm2d_load4_strided_f32(const float* p, const int stride) {
  const float lanes[4] = {p[0], p[stride], p[2 * stride], p[3 * stride]};
  return vld1q_f32(lanes);
}
#endif

// Layer normalisation over the channel axis of an NCHW fp32 tensor.
//
// For every (n, h, w) position, the C channel values are normalised with their mean and
// population variance (sum of squared deviations divided by C), then scaled and shifted per channel:
//   y[n, c, h, w] = (x[n, c, h, w] - mean) / sqrt(variance + eps) * weight[c] + bias[c]
//
// x, y   : N * C * H * W floats, element (n, c, i) at n * C * H * W + c * H * W + i
// weight : C floats
// bias   : C floats
// eps    : added to the variance before the square root
void layernorm2d_fp32(const float* x, const float* weight, const float* bias, float* y, int N, int C, int H, int W, float eps) {
  const int spatial_dim = H * W;

  for (int n = 0; n < N; ++n) {
    for (int i = 0; i < spatial_dim; ++i) {
      // Consecutive channels of one spatial position are spatial_dim elements apart.
      const float* x_ptr = x + n * C * spatial_dim + i;
      float* y_ptr = y + n * C * spatial_dim + i;

      // Pass 1: mean over channels.
      int c = 0;
      float sum = 0.0f;
#if defined(__ARM_NEON)
      float32x4_t sum_vec = vdupq_n_f32(0.0f);
      for (; c <= C - 4; c += 4) {
        sum_vec = vaddq_f32(sum_vec, layernorm2d_load4_strided_f32(x_ptr + c * spatial_dim, spatial_dim));
      }
      sum = layernorm2d_hsum_f32(sum_vec);
#endif
      // Scalar tail (or the whole range when NEON is unavailable).
      for (; c < C; ++c) { sum += x_ptr[c * spatial_dim]; }
      const float mean = sum / C;

      // Pass 2: sum of squared deviations from the mean.
      c = 0;
      float sq_sum = 0.0f;
#if defined(__ARM_NEON)
      const float32x4_t mean_vec = vdupq_n_f32(mean);
      float32x4_t sq_sum_vec = vdupq_n_f32(0.0f);
      for (; c <= C - 4; c += 4) {
        const float32x4_t x_vec = layernorm2d_load4_strided_f32(x_ptr + c * spatial_dim, spatial_dim);
        const float32x4_t diff = vsubq_f32(x_vec, mean_vec);
        // Multiply-accumulate: sq_sum_vec += diff * diff (vmlaq_f32 is not guaranteed to be fused).
        sq_sum_vec = vmlaq_f32(sq_sum_vec, diff, diff);
      }
      sq_sum = layernorm2d_hsum_f32(sq_sum_vec);
#endif
      for (; c < C; ++c) {
        const float diff = x_ptr[c * spatial_dim] - mean;
        sq_sum += diff * diff;
      }
      const float variance = sq_sum / C;
      const float inv_std = 1.0f / std::sqrt(variance + eps);

      // Pass 3: normalise, then apply the per-channel affine transform.
      c = 0;
#if defined(__ARM_NEON)
      const float32x4_t inv_std_vec = vdupq_n_f32(inv_std);
      for (; c <= C - 4; c += 4) {
        const float32x4_t x_vec = layernorm2d_load4_strided_f32(x_ptr + c * spatial_dim, spatial_dim);
        const float32x4_t weight_vec = vld1q_f32(weight + c);
        const float32x4_t bias_vec = vld1q_f32(bias + c);

        // norm = (x - mean) * inv_std
        const float32x4_t norm_val = vmulq_f32(vsubq_f32(x_vec, mean_vec), inv_std_vec);
        // out = norm * weight + bias
        const float32x4_t out_vec = vmlaq_f32(bias_vec, norm_val, weight_vec);

        // Scatter the lanes back to their strided channel positions.
        y_ptr[c * spatial_dim] = vgetq_lane_f32(out_vec, 0);
        y_ptr[(c + 1) * spatial_dim] = vgetq_lane_f32(out_vec, 1);
        y_ptr[(c + 2) * spatial_dim] = vgetq_lane_f32(out_vec, 2);
        y_ptr[(c + 3) * spatial_dim] = vgetq_lane_f32(out_vec, 3);
      }
#endif
      for (; c < C; ++c) { y_ptr[c * spatial_dim] = (x_ptr[c * spatial_dim] - mean) * inv_std * weight[c] + bias[c]; }
    }
  }
}

} // namespace mllm::cpu::arm

#endif
18 changes: 18 additions & 0 deletions mllm/backends/cpu/kernels/arm/layernorm2d.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright (c) MLLM Team.
// Licensed under the MIT License.

#pragma once

#include "mllm/utils/CPUArchHelper.hpp"

#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)

#include <arm_neon.h>

namespace mllm::cpu::arm {

// Layer normalisation over the channel axis of an NCHW fp32 tensor.
// For each (n, h, w) position, the C channel values are normalised with their mean and variance
// (eps is added to the variance), then scaled by weight[c] and shifted by bias[c].
// x and y hold N * C * H * W floats; weight and bias hold C floats.
void layernorm2d_fp32(const float* x, const float* weight, const float* bias, float* y, int N, int C, int H, int W, float eps);

} // namespace mllm::cpu::arm
#endif
Loading