Skip to content

Commit f3b489b

Browse files
committed
ggml: allow casting between f32 and i32
1 parent dff7551 commit f3b489b

File tree

8 files changed

+184
-1
lines changed

8 files changed

+184
-1
lines changed

ggml/include/ggml-cpu.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ extern "C" {
135135
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
136136

137137
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
138+
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
138139
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
139140
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
140141
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
373373
.vec_dot_type = GGML_TYPE_Q8_K,
374374
.nrows = 1,
375375
},
376+
[GGML_TYPE_I32] = {
377+
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
378+
},
376379
};
377380

378381
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -2691,7 +2694,10 @@ struct ggml_cplan ggml_graph_plan(
26912694
if (ggml_is_quantized(node->type) ||
26922695
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
26932696
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
2694-
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
2697+
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
2698+
// conversion between F32 and I32
2699+
(node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
2700+
(node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
26952701
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
26962702
}
26972703
} break;
@@ -3283,6 +3289,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
32833289
}
32843290
}
32853291

3292+
// Convert n F32 values to I32.
// Uses C's float-to-integer conversion, which truncates toward zero
// (e.g. 1.9f -> 1, -1.9f -> -1). Values outside the representable
// int32_t range are undefined behavior per the C standard.
// NOTE(review): the CUDA path added in this commit uses __float2int_rd
// (round toward -inf), which differs from truncation for negative
// non-integral inputs — confirm the intended rounding mode is consistent
// across backends.
void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        y[i] = (int32_t) x[i];
    }
}
3298+
32863299
void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
32873300
int64_t i = 0;
32883301
#if defined(__AVX2__)

ggml/src/ggml-cpu/ops.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,144 @@ static void ggml_compute_forward_dup_f32(
947947
}
948948
}
949949
}
950+
} else if (dst->type == GGML_TYPE_I32) {
951+
for (int64_t i03 = 0; i03 < ne03; i03++) {
952+
for (int64_t i02 = 0; i02 < ne02; i02++) {
953+
i10 += ne00 * ir0;
954+
while (i10 >= ne0) {
955+
i10 -= ne0;
956+
if (++i11 == ne1) {
957+
i11 = 0;
958+
if (++i12 == ne2) {
959+
i12 = 0;
960+
if (++i13 == ne3) {
961+
i13 = 0;
962+
}
963+
}
964+
}
965+
}
966+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
967+
for (int64_t i00 = 0; i00 < ne00; i00++) {
968+
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
969+
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
970+
971+
*(int32_t *) dst_ptr = *(const float *) src0_ptr;
972+
973+
if (++i10 == ne0) {
974+
i10 = 0;
975+
if (++i11 == ne1) {
976+
i11 = 0;
977+
if (++i12 == ne2) {
978+
i12 = 0;
979+
if (++i13 == ne3) {
980+
i13 = 0;
981+
}
982+
}
983+
}
984+
}
985+
}
986+
}
987+
i10 += ne00 * (ne01 - ir1);
988+
while (i10 >= ne0) {
989+
i10 -= ne0;
990+
if (++i11 == ne1) {
991+
i11 = 0;
992+
if (++i12 == ne2) {
993+
i12 = 0;
994+
if (++i13 == ne3) {
995+
i13 = 0;
996+
}
997+
}
998+
}
999+
}
1000+
}
1001+
}
1002+
} else {
1003+
GGML_ABORT("fatal error"); // TODO: implement
1004+
}
1005+
}
1006+
1007+
static void ggml_compute_forward_dup_i32(
1008+
const ggml_compute_params * params,
1009+
ggml_tensor * dst) {
1010+
1011+
const ggml_tensor * src0 = dst->src[0];
1012+
1013+
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
1014+
1015+
GGML_TENSOR_UNARY_OP_LOCALS
1016+
1017+
const int ith = params->ith; // thread index
1018+
const int nth = params->nth; // number of threads
1019+
1020+
// parallelize by rows
1021+
const int nr = ne01;
1022+
// number of rows per thread
1023+
const int dr = (nr + nth - 1) / nth;
1024+
// row range for this thread
1025+
const int ir0 = dr * ith;
1026+
const int ir1 = MIN(ir0 + dr, nr);
1027+
1028+
// dst counters
1029+
1030+
int64_t i10 = 0;
1031+
int64_t i11 = 0;
1032+
int64_t i12 = 0;
1033+
int64_t i13 = 0;
1034+
1035+
// TODO: not optimal, but works
1036+
if (dst->type == GGML_TYPE_F32) {
1037+
for (int64_t i03 = 0; i03 < ne03; i03++) {
1038+
for (int64_t i02 = 0; i02 < ne02; i02++) {
1039+
i10 += ne00 * ir0;
1040+
while (i10 >= ne0) {
1041+
i10 -= ne0;
1042+
if (++i11 == ne1) {
1043+
i11 = 0;
1044+
if (++i12 == ne2) {
1045+
i12 = 0;
1046+
if (++i13 == ne3) {
1047+
i13 = 0;
1048+
}
1049+
}
1050+
}
1051+
}
1052+
for (int64_t i01 = ir0; i01 < ir1; i01++) {
1053+
for (int64_t i00 = 0; i00 < ne00; i00++) {
1054+
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
1055+
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
1056+
1057+
*(float *) dst_ptr = *(const int32_t *) src0_ptr;
1058+
1059+
if (++i10 == ne0) {
1060+
i10 = 0;
1061+
if (++i11 == ne1) {
1062+
i11 = 0;
1063+
if (++i12 == ne2) {
1064+
i12 = 0;
1065+
if (++i13 == ne3) {
1066+
i13 = 0;
1067+
}
1068+
}
1069+
}
1070+
}
1071+
}
1072+
}
1073+
i10 += ne00 * (ne01 - ir1);
1074+
while (i10 >= ne0) {
1075+
i10 -= ne0;
1076+
if (++i11 == ne1) {
1077+
i11 = 0;
1078+
if (++i12 == ne2) {
1079+
i12 = 0;
1080+
if (++i13 == ne3) {
1081+
i13 = 0;
1082+
}
1083+
}
1084+
}
1085+
}
1086+
}
1087+
}
9501088
} else {
9511089
GGML_ABORT("fatal error"); // TODO: implement
9521090
}
@@ -1177,6 +1315,10 @@ void ggml_compute_forward_dup(
11771315
{
11781316
ggml_compute_forward_dup_f32(params, dst);
11791317
} break;
1318+
case GGML_TYPE_I32:
1319+
{
1320+
ggml_compute_forward_dup_i32(params, dst);
1321+
} break;
11801322
default:
11811323
{
11821324
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {

ggml/src/ggml-cuda/convert.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ template<typename dst_t, typename src_t>
3838
return __float2bfloat16(float(x));
3939
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
4040
return __bfloat162float(x);
41+
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
42+
return __float2int_rd(x);
43+
} else if constexpr(std::is_same_v<src_t, int32_t>) {
44+
return __int2float_rd(x);
4145
} else {
4246
return float(x);
4347
}

ggml/src/ggml-cuda/cpy.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
374374
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
375375
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
376376
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
377+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
378+
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
379+
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
380+
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
377381
} else {
378382
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
379383
ggml_type_name(src0->type), ggml_type_name(src1->type));

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
582582
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
583583
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
584584
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
585+
GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
586+
GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
585587
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
586588
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
587589
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -1614,6 +1616,8 @@ @implementation GGMLMetalClass
16141616
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
16151617
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
16161618
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
1619+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true);
1620+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true);
16171621
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
16181622
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
16191623
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -1940,6 +1944,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19401944
case GGML_TYPE_Q5_0:
19411945
case GGML_TYPE_Q5_1:
19421946
case GGML_TYPE_IQ4_NL:
1947+
case GGML_TYPE_I32:
19431948
return true;
19441949
default:
19451950
return false;
@@ -1972,6 +1977,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
19721977
default:
19731978
return false;
19741979
}
1980+
case GGML_TYPE_I32:
1981+
return op->type == GGML_TYPE_F32;
19751982
default:
19761983
return false;
19771984
};
@@ -5674,6 +5681,7 @@ static int ggml_metal_encode_node(
56745681

56755682
switch (dstt) {
56765683
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
5684+
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
56775685
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
56785686
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
56795687
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
@@ -5685,6 +5693,13 @@ static int ggml_metal_encode_node(
56855693
default: GGML_ABORT("not implemented");
56865694
};
56875695
} break;
5696+
case GGML_TYPE_I32:
5697+
{
5698+
switch (dstt) {
5699+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
5700+
default: GGML_ABORT("not implemented");
5701+
};
5702+
} break;
56885703
case GGML_TYPE_F16:
56895704
{
56905705
switch (dstt) {

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
53385338

53395339
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
53405340
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
5341+
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
5342+
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
53415343
#if defined(GGML_METAL_USE_BF16)
53425344
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
53435345
#endif

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5869,6 +5869,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
58695869
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
58705870
}
58715871
}
5872+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}));
5873+
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
58725874

58735875
test_cases.emplace_back(new test_cont());
58745876
test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));

Comments: 0 commit comments