@@ -6648,14 +6648,143 @@ static void ggml_compute_forward_repeat_back(
66486648
66496649// ggml_compute_forward_concat
66506650
6651+ static void ggml_compute_forward_concat_any(
6652+ const struct ggml_compute_params * params,
6653+ struct ggml_tensor * dst) {
6654+
6655+ const struct ggml_tensor * src0 = dst->src[0];
6656+ const struct ggml_tensor * src1 = dst->src[1];
6657+
6658+ const size_t len = ggml_type_size(src0->type);
6659+
6660+ const int ith = params->ith;
6661+ const int nth = params->nth;
6662+
6663+ GGML_TENSOR_BINARY_OP_LOCALS
6664+
6665+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6666+
6667+ GGML_ASSERT(dim >= 0 && dim < 4);
6668+
6669+ int64_t o[4] = {0, 0, 0, 0};
6670+ o[dim] = src0->ne[dim];
6671+
6672+ const char * x;
6673+
6674+ // TODO: smarter multi-theading
6675+ for (int i3 = 0; i3 < ne3; i3++) {
6676+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6677+ for (int i1 = 0; i1 < ne1; i1++) {
6678+ for (int i0 = 0; i0 < ne0; i0++) {
6679+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6680+ x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
6681+ } else {
6682+ x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
6683+ }
6684+
6685+ char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
6686+
6687+ memcpy(y, x, len);
6688+ }
6689+ }
6690+ }
6691+ }
6692+ }
6693+
6694+ static void ggml_compute_forward_concat_i8(
6695+ const struct ggml_compute_params * params,
6696+ struct ggml_tensor * dst) {
6697+
6698+ const struct ggml_tensor * src0 = dst->src[0];
6699+ const struct ggml_tensor * src1 = dst->src[1];
6700+
6701+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
6702+
6703+ const int ith = params->ith;
6704+ const int nth = params->nth;
6705+
6706+ GGML_TENSOR_BINARY_OP_LOCALS
6707+
6708+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6709+
6710+ GGML_ASSERT(dim >= 0 && dim < 4);
6711+
6712+ int64_t o[4] = {0, 0, 0, 0};
6713+ o[dim] = src0->ne[dim];
6714+
6715+ const int8_t * x;
6716+
6717+ // TODO: smarter multi-theading
6718+ for (int i3 = 0; i3 < ne3; i3++) {
6719+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6720+ for (int i1 = 0; i1 < ne1; i1++) {
6721+ for (int i0 = 0; i0 < ne0; i0++) {
6722+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6723+ x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6724+ } else {
6725+ x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6726+ }
6727+
6728+ int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6729+
6730+ *y = *x;
6731+ }
6732+ }
6733+ }
6734+ }
6735+ }
6736+
6737+ static void ggml_compute_forward_concat_f16(
6738+ const struct ggml_compute_params * params,
6739+ struct ggml_tensor * dst) {
6740+
6741+ const struct ggml_tensor * src0 = dst->src[0];
6742+ const struct ggml_tensor * src1 = dst->src[1];
6743+
6744+ GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
6745+
6746+ const int ith = params->ith;
6747+ const int nth = params->nth;
6748+
6749+ GGML_TENSOR_BINARY_OP_LOCALS
6750+
6751+ const int32_t dim = ggml_get_op_params_i32(dst, 0);
6752+
6753+ GGML_ASSERT(dim >= 0 && dim < 4);
6754+
6755+ int64_t o[4] = {0, 0, 0, 0};
6756+ o[dim] = src0->ne[dim];
6757+
6758+ const ggml_fp16_t * x;
6759+
6760+ // TODO: smarter multi-theading
6761+ for (int i3 = 0; i3 < ne3; i3++) {
6762+ for (int i2 = ith; i2 < ne2; i2 += nth) {
6763+ for (int i1 = 0; i1 < ne1; i1++) {
6764+ for (int i0 = 0; i0 < ne0; i0++) {
6765+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
6766+ x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
6767+ } else {
6768+ x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
6769+ }
6770+
6771+ ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
6772+
6773+ *y = *x;
6774+ }
6775+ }
6776+ }
6777+ }
6778+ }
6779+
66516780static void ggml_compute_forward_concat_f32(
66526781 const struct ggml_compute_params * params,
66536782 struct ggml_tensor * dst) {
66546783
66556784 const struct ggml_tensor * src0 = dst->src[0];
66566785 const struct ggml_tensor * src1 = dst->src[1];
66576786
6658- GGML_ASSERT(src0->nb[0] == sizeof(float));
6787+ GGML_ASSERT(ggml_type_size( src0->type) == sizeof(float));
66596788
66606789 const int ith = params->ith;
66616790 const int nth = params->nth;
@@ -6698,14 +6827,24 @@ static void ggml_compute_forward_concat(
66986827 const struct ggml_tensor * src0 = dst->src[0];
66996828
67006829 switch (src0->type) {
6830+ case GGML_TYPE_F16:
6831+ case GGML_TYPE_BF16:
6832+ case GGML_TYPE_I16:
6833+ {
6834+ ggml_compute_forward_concat_f16(params, dst);
6835+ } break;
6836+ case GGML_TYPE_I8:
6837+ {
6838+ ggml_compute_forward_concat_i8(params, dst);
6839+ } break;
67016840 case GGML_TYPE_F32:
67026841 case GGML_TYPE_I32:
67036842 {
67046843 ggml_compute_forward_concat_f32(params, dst);
67056844 } break;
67066845 default:
67076846 {
6708- GGML_ABORT("fatal error" );
6847+ ggml_compute_forward_concat_any(params, dst );
67096848 }
67106849 }
67116850}
0 commit comments