1 change: 1 addition & 0 deletions ggml/include/ggml-cpu.h
@@ -135,6 +135,7 @@ extern "C" {
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);

GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
1 change: 1 addition & 0 deletions ggml/include/ggml.h
@@ -1403,6 +1403,7 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// note: casting from f32 to i32 will discard the fractional part
GGML_API struct ggml_tensor * ggml_cast(
struct ggml_context * ctx,
struct ggml_tensor * a,
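For context, here is a minimal, illustrative sketch of how the new cast could be exercised through the public API. It is not part of this PR, and it assumes the plain CPU path (`ggml_graph_compute_with_ctx`) with tensor data allocated inside the context:

```cpp
// Sketch only: cast an F32 tensor to I32 on the CPU backend.
// Error handling is omitted; the truncation follows C's float-to-int rules.
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ NULL, /*no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * y = ggml_cast(ctx, x, GGML_TYPE_I32); // fractional part is discarded

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    float * xd = (float *) x->data;
    xd[0] = 0.5f; xd[1] = 1.9f; xd[2] = -1.9f; xd[3] = 42.0f;

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // ((int32_t *) y->data) now holds {0, 1, -1, 42}

    ggml_free(ctx);
    return 0;
}
```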
15 changes: 14 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu.c
@@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
};

const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -2691,7 +2694,10 @@ struct ggml_cplan ggml_graph_plan(
if (ggml_is_quantized(node->type) ||
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
// conversion between F32 and I32
(node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
(node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
}
} break;
@@ -3283,6 +3289,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
}
}

void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
int64_t i = 0;
for (; i < n; ++i) {
y[i] = x[i];
}
}

void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
int64_t i = 0;
#if defined(__AVX2__)
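The scalar loop in `ggml_cpu_fp32_to_i32` relies on C's implicit float-to-integer conversion, which truncates toward zero (values outside the `int32_t` range are undefined behavior). A tiny standalone illustration, not part of the diff:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const float xs[] = { 1.9f, -1.9f, 0.5f, -0.5f };
    for (float x : xs) {
        int32_t y = x; // same implicit conversion as y[i] = x[i] above
        printf("%5.2f -> %d\n", x, y); // prints 1, -1, 0, 0
    }
    return 0;
}
```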
160 changes: 160 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
id += ne00 * (ne01 - ir1);
}
}
} else if (dst->type == GGML_TYPE_I32) {
size_t id = 0;
int32_t * dst_ptr = (int32_t *) dst->data;

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);

dst_ptr[id] = *src0_ptr;
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
Comment on lines +779 to +796

Member
Should we merge this into the F32 branch above?

Collaborator (Author)
Yes, and indeed I also want to migrate some of this code into a template function. WDYT?

Member
Yes, refactoring this code is welcome.

Collaborator (Author)
I'll merge this PR as-is and will open another PR to refactor this code.

} else {
GGML_ABORT("fatal error"); // TODO: implement
}
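As a rough idea of the refactor discussed in the thread above, the contiguous-rows branches (F16, BF16, and now I32) could collapse into a helper templated on the destination element type. This is a hypothetical sketch, not code from this PR; the name and the handling of F16/BF16 conversion are assumptions:

```cpp
// Hypothetical sketch of the templated dup helper mentioned in the review above.
// It mirrors the contiguous-rows branches of ggml_compute_forward_dup_f32.
template <typename dst_elem_t>
static void dup_rows_from_f32(const ggml_tensor * src0, ggml_tensor * dst, int ir0, int ir1) {
    GGML_TENSOR_UNARY_OP_LOCALS

    size_t id = 0;
    dst_elem_t * dst_ptr = (dst_elem_t *) dst->data;

    for (int i03 = 0; i03 < ne03; i03++) {
        for (int i02 = 0; i02 < ne02; i02++) {
            id += ne00 * ir0;
            for (int i01 = ir0; i01 < ir1; i01++) {
                for (int i00 = 0; i00 < ne00; i00++) {
                    const float * src0_ptr = (const float *) ((const char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
                    // int32_t truncates; an F16/BF16 instantiation would need a conversion helper instead
                    dst_ptr[id++] = dst_elem_t(*src0_ptr);
                }
            }
            id += ne00 * (ne01 - ir1);
        }
    }
}
```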
@@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
}
}
}
} else if (dst->type == GGML_TYPE_I32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

*(int32_t *) dst_ptr = *(const float *) src0_ptr;

if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
}

static void ggml_compute_forward_dup_i32(
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));

GGML_TENSOR_UNARY_OP_LOCALS

const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads

// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);

// dst counters

int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;

// TODO: not optimal, but works
if (dst->type == GGML_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);

*(float *) dst_ptr = *(const int32_t *) src0_ptr;

if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
@@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
{
ggml_compute_forward_dup_f32(params, dst);
} break;
case GGML_TYPE_I32:
{
ggml_compute_forward_dup_i32(params, dst);
} break;
default:
{
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/convert.cuh
@@ -38,6 +38,8 @@ template<typename dst_t, typename src_t>
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
return float(x);
}
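For readers less familiar with the compile-time dispatch used here, the snippet below is a simplified, host-only C++ analogue of the branch the new `int32_t` case adds; it is illustrative only and does not reproduce the actual helper in `convert.cuh`, which also handles the half/bfloat16 device types:

```cpp
// Illustrative only: simplified analogue of the compile-time cast dispatch.
#include <cstdint>
#include <type_traits>

template <typename dst_t, typename src_t>
static dst_t cast_elem(src_t x) {
    if constexpr (std::is_same_v<dst_t, int32_t>) {
        return int32_t(x); // float -> int32 truncates toward zero, matching the CPU path
    } else {
        return dst_t(x);
    }
}

// e.g. cast_elem<int32_t>(-1.9f) == -1, cast_elem<float>(7) == 7.0f
```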
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/cpy.cu
@@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
6 changes: 6 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3458,6 +3458,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
return true;
}
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true;
}
15 changes: 15 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -582,6 +582,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -1614,6 +1616,8 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -1940,6 +1944,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_I32:
return true;
default:
return false;
Expand Down Expand Up @@ -1972,6 +1977,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
default:
return false;
}
case GGML_TYPE_I32:
return op->type == GGML_TYPE_F32;
default:
return false;
};
@@ -5674,6 +5681,7 @@ static int ggml_metal_encode_node(

switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
@@ -5685,6 +5693,13 @@
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_I32:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_F16:
{
switch (dstt) {
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
@@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;

template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
#endif