diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 2f06e1e39b225..342678e2b73db 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -503,6 +503,7 @@ extern "C" {
         GGML_OP_IM2COL,
         GGML_OP_IM2COL_BACK,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_3D,
         GGML_OP_CONV_2D_DW,
         GGML_OP_CONV_TRANSPOSE_2D,
         GGML_OP_POOL_1D,
@@ -1914,6 +1915,23 @@ extern "C" {
             int                   d0,  // dilation dimension 0
             int                   d1); // dilation dimension 1
 
+    GGML_API struct ggml_tensor * ggml_conv_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // kernel [KW, KH, KD, IC * OC]
+            struct ggml_tensor  * b,   // input  [W, H, D, C * N]
+            int                   s0,  // stride
+            int                   s1,
+            int                   s2,
+            int                   p0,  // padding
+            int                   p1,
+            int                   p2,
+            int                   d0,  // dilation
+            int                   d1,
+            int                   d2,
+            int                   n_channels,
+            int                   n_batch,
+            int                   n_channels_out);
+
     enum ggml_op_pool {
         GGML_OP_POOL_MAX,
         GGML_OP_POOL_AVG,
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index d89cd8f4ef652..255f42848cac8 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1880,6 +1880,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_conv_2d(params, tensor);
             } break;
+        case GGML_OP_CONV_3D:
+            {
+                ggml_compute_forward_conv_3d(params, tensor);
+            } break;
         case GGML_OP_CONV_2D_DW:
             {
                 ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2247,6 +2251,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_IM2COL:
         case GGML_OP_IM2COL_BACK:
         case GGML_OP_CONV_2D:
+        case GGML_OP_CONV_3D:
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_1D:
         case GGML_OP_CONV_TRANSPOSE_2D:
@@ -2767,6 +2772,7 @@ struct ggml_cplan ggml_graph_plan(
                     }
                 } break;
             case GGML_OP_CONV_2D:
+            case GGML_OP_CONV_3D:
                 {
                     cur = GGML_IM2COL_WORK_SIZE;
                 } break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 854f1c2b49647..ec49d5d47c0a4 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
     ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
+// ggml_compute_forward_conv_3d
+
+static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,
+                                              const ggml_tensor *         src,
+                                              ggml_tensor *               dst,
+                                              ggml_type                   kernel_type) {
+
+    GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
+
+    const int32_t s0 = dst->op_params[0];
+    const int32_t s1 = dst->op_params[1];
+    const int32_t s2 = dst->op_params[2];
+    const int32_t p0 = dst->op_params[3];
+    const int32_t p1 = dst->op_params[4];
+    const int32_t p2 = dst->op_params[5];
+    const int32_t d0 = dst->op_params[6];
+    const int32_t d1 = dst->op_params[7];
+    const int32_t d2 = dst->op_params[8];
+    const int32_t c  = dst->op_params[9];
+    const int32_t n  = dst->op_params[10];
+    const int32_t oc = dst->op_params[11];
+
+    const int64_t src_w = src->ne[0];
+    const int64_t src_h = src->ne[1];
+    const int64_t src_d = src->ne[2];
+    const int64_t knl_w = kernel->ne[0];
+    const int64_t knl_h = kernel->ne[1];
+    const int64_t knl_d = kernel->ne[2];
+    const int64_t dst_w = dst->ne[0];
+    const int64_t dst_h = dst->ne[1];
+    const int64_t dst_d = dst->ne[2];
+
+    const float * src_data = (float *) src->data;
+    void  *       knl_data = kernel->data;
+    float *       dst_data = (float *) dst->data;
+
+    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
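+    // workspace layout per batch of patches: patches_per_batch im2col rows of
+    // knl_n_total elements each, followed by the f32 GEMM output
+    // (patches_per_batch * oc floats); both must fit in params->wdata/wsize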
+    const int64_t knl_n_total = knl_n_per_channel * c;
+    const int64_t patch_total = n * dst_w * dst_h * dst_d;
+
+    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
+    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
+
+    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
+
+    void * tmp = params->wdata;
+
+    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
+        const int64_t patch_start_batch = batch_i * patches_per_batch;
+        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
+        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
+
+        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
+
+        for (int64_t p = patch_start; p < patch_end; ++p) {
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
+
+            for (int64_t ic = 0; ic < c; ++ic) {
+                for (int64_t kz = 0; kz < knl_d; ++kz) {
+                    for (int64_t ky = 0; ky < knl_h; ++ky) {
+                        for (int64_t kx = 0; kx < knl_w; ++kx) {
+                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
+                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
+                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
+
+                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
+
+                            float src_val;
+                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
+                                src_val = 0.0f;
+                            } else {
+                                const int64_t cn_idx  = batch_idx * c + ic;
+                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
+                                src_val = *src_ptr;
+                            }
+
+                            char * element_ptr = dst_row + dst_idx * traits->type_size;
+                            if (kernel_type == GGML_TYPE_F32) {
+                                *(float *) element_ptr = src_val;
+                            } else if (kernel_type == GGML_TYPE_F16) {
+                                *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        ggml_barrier(params->threadpool);
+
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
+        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
+
+        ggml_barrier(params->threadpool);
+
+        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
+        const int64_t permute_start      = params->ith * permute_per_thread;
+        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n_in_batch);
+
+        for (int64_t i = permute_start; i < permute_end; ++i) {
+            const int64_t p          = patch_start_batch + i;
+            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
+            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
+            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
+            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
+            const int64_t dst_y      = p_in_depth / dst_w;
+            const int64_t dst_x      = p_in_depth % dst_w;
+
+            for (int64_t ioc = 0; ioc < oc; ++ioc) {
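+                // the GEMM result is patch-major: row i holds the oc output
+                // channels for patch i; scatter it into the [W, H, D, OC*N]
+                // destination layout below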
+                const float   value   = gemm_output[i * oc + ioc];
+                const int64_t ocn_idx = batch_idx * oc + ioc;
+                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
+                *dst_ptr = value;
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_conv_3d(
+        const ggml_compute_params * params,
+        ggml_tensor *               dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
+}
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index f154afb462498..4040c1081e478 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -70,6 +70,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p
 void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 55a76f8248c09..60ecca3146f87 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -975,6 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "IM2COL",
     "IM2COL_BACK",
     "CONV_2D",
+    "CONV_3D",
     "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
@@ -1016,7 +1017,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1076,6 +1077,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "im2col(x)",
     "im2col_back(x)",
     "conv_2d(x)",
+    "conv_3d(x)",
     "conv_2d_dw(x)",
     "conv_transpose_2d(x)",
     "pool_1d(x)",
@@ -1117,7 +1119,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
+static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4476,6 +4478,56 @@ struct ggml_tensor * ggml_conv_2d_direct(
     return result;
 }
 
+// ggml_conv_3d
+
+struct ggml_tensor * ggml_conv_3d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   s2,
+        int                   p0,
+        int                   p1,
+        int                   p2,
+        int                   d0,
+        int                   d1,
+        int                   d2,
+        int                   c,
+        int                   n,
+        int                   oc) {
+
+    GGML_ASSERT(a->ne[3] == (int64_t) c * oc);
+    GGML_ASSERT(b->ne[3] == (int64_t) c * n);
+
+    int64_t ne[4];
+    ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+    ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+    ne[2] = ggml_calc_conv_output_size(b->ne[2], a->ne[2], s2, p2, d2);
+    ne[3] = (int64_t) oc * n;
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_i32(result, 0, s0);
+    ggml_set_op_params_i32(result, 1, s1);
+    ggml_set_op_params_i32(result, 2, s2);
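+    // op_params layout: [0..2] stride, [3..5] padding, [6..8] dilation,
+    // [9] input channels, [10] batch, [11] output channels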
+    ggml_set_op_params_i32(result, 3, p0);
+    ggml_set_op_params_i32(result, 4, p1);
+    ggml_set_op_params_i32(result, 5, p2);
+    ggml_set_op_params_i32(result, 6, d0);
+    ggml_set_op_params_i32(result, 7, d1);
+    ggml_set_op_params_i32(result, 8, d2);
+    ggml_set_op_params_i32(result, 9, c);
+    ggml_set_op_params_i32(result, 10, n);
+    ggml_set_op_params_i32(result, 11, oc);
+
+    result->op     = GGML_OP_CONV_3D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index d29779cd12b22..4bbe79ca20a7b 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4063,6 +4063,75 @@ struct test_conv_2d_dw : public test_case {
     }
 };
 
+// GGML_OP_CONV_3D
+struct test_conv_3d : public test_case {
+    // Logical 5D dimensions
+    const int64_t N, IC, ID, IH, IW;
+    const int64_t OC, KD, KH, KW;
+    // Conv params
+    const int s0, s1, s2;
+    const int p0, p1, p2;
+    const int d0, d1, d2;
+    // Types
+    const ggml_type type_kernel;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "CONV_3D";
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR11(N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1) + "," +
+               VARS_TO_STR8(s2, p0, p1, p2, d0, d1, d2, type_kernel);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+            return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+        };
+        const int64_t OD = calc_conv_output_size(ID, KD, s2, p2, d2);
+        const int64_t OH = calc_conv_output_size(IH, KH, s1, p1, d1);
+        const int64_t OW = calc_conv_output_size(IW, KW, s0, p0, d0);
+
+        return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
+    }
+
+    test_conv_3d(
+        int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
+        int64_t OC, int64_t KD, int64_t KH, int64_t KW,
+        int s0, int s1, int s2,
+        int p0, int p1, int p2,
+        int d0, int d1, int d2,
+        ggml_type type_kernel
+    ) : N(N), IC(IC), ID(ID), IH(IH), IW(IW),
+        OC(OC), KD(KD), KH(KH), KW(KW),
+        s0(s0), s1(s1), s2(s2),
+        p0(p0), p1(p1), p2(p2),
+        d0(d0), d1(d1), d2(d2),
+        type_kernel(type_kernel) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // GGML input tensor is packed as [W, H, D, C*N]
+        const int64_t ne_input[] = {IW, IH, ID, IC * N};
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input);
+        ggml_set_name(input, "input");
+
+        // GGML kernel tensor is packed as [KW, KH, KD, IC*OC]
+        const int64_t ne_kernel[] = {KW, KH, KD, IC * OC};
+        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
+        ggml_set_name(kernel, "kernel");
+
+        ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
@@ -5461,6 +5530,61 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
     test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
 
+    // CONV_3D
+    auto calc_conv_output_size_3d = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t {
+        return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
+    };
+
+    for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+        for (int N : {1, 2}) {
+            for (int IC : {1, 3}) {
+                for (int OC : {1, 4}) {
+                    for (int s0 : {1, 2}) {
+                        for (int p1 : {0, 1}) {
+                            for (int d2 : {1, 2}) {
+                                int64_t IW = 20, IH = 22, ID = 18;
+                                int64_t KW = 3, KH = 3, KD = 3;
+                                int s1 = s0, s2 = s0;
+                                int p0 = p1, p2 = p1;
+                                int d0 = d2, d1 = d2;
+
+                                if (calc_conv_output_size_3d(IW, KW, s0, p0, d0) <= 0 ||
+                                    calc_conv_output_size_3d(IH, KH, s1, p1, d1) <= 0 ||
+                                    calc_conv_output_size_3d(ID, KD, s2, p2, d2) <= 0) {
+                                    continue;
+                                }
+                                test_cases.emplace_back(new test_conv_3d(
+                                    N, IC, ID, IH, IW,
+                                    OC, KD, KH, KW,
+                                    s0, s1, s2, p0, p1, p2, d0, d1, d2,
+                                    kernel_type));
+
+                                // Asymmetric kernel and params
+                                int64_t asym_KW = 5, asym_KH = 1, asym_KD = 3;
+                                int asym_s0 = 2, asym_s1 = 1, asym_s2 = 1;
+                                int asym_p0 = 2, asym_p1 = 0, asym_p2 = 1;
+                                int asym_d0 = 1, asym_d1 = 1, asym_d2 = 2;
+
+                                if (calc_conv_output_size_3d(IW, asym_KW, asym_s0, asym_p0, asym_d0) <= 0 ||
+                                    calc_conv_output_size_3d(IH, asym_KH, asym_s1, asym_p1, asym_d1) <= 0 ||
+                                    calc_conv_output_size_3d(ID, asym_KD, asym_s2, asym_p2, asym_d2) <= 0) {
+                                    continue;
+                                }
+                                test_cases.emplace_back(new test_conv_3d(
+                                    N, IC, ID, IH, IW,
+                                    OC, asym_KD, asym_KH, asym_KW,
+                                    asym_s0, asym_s1, asym_s2, asym_p0, asym_p1, asym_p2, asym_d0, asym_d1, asym_d2,
+                                    kernel_type));
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        // Case with kernel size 1
+        test_cases.emplace_back(new test_conv_3d(1, 4, 8, 8, 8, 8, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, kernel_type));
+    }
+
     for(uint32_t Cout : {1, 9}){
         for(uint32_t Cin : {1, 7}){
             for(uint32_t K : {1, 3, 1337}){
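---

For reference, a minimal sketch (not part of the diff) of how the new API is meant to be called, following the packed [W, H, D, C*N] input and [KW, KH, KD, IC*OC] kernel layouts documented above. The memory size and the toy shapes are illustrative assumptions; the snippet only builds the graph node, it does not run the graph.

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 64*1024*1024,  // illustrative scratch size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // kernel packed as [KW, KH, KD, IC*OC]: 3x3x3 kernel, IC = 3, OC = 8
        struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 3, 3*8);
        // input packed as [W, H, D, IC*N]: 16x16x16 volume, IC = 3, N = 2
        struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 16, 16, 16, 3*2);

        struct ggml_tensor * out = ggml_conv_3d(ctx, kernel, input,
                1, 1, 1,   // s0, s1, s2
                1, 1, 1,   // p0, p1, p2
                1, 1, 1,   // d0, d1, d2
                3, 2, 8);  // n_channels, n_batch, n_channels_out
        // out is F32, packed as [OW, OH, OD, OC*N] = [16, 16, 16, 16]

        ggml_free(ctx);
        return 0;
    }

Note that, unlike ggml_conv_2d, the channel and batch counts cannot be inferred from the packed 4D shapes alone, which is why IC, N, and OC are passed explicitly and asserted against a->ne[3] == IC*OC and b->ne[3] == IC*N.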