
Commit 4718203

Merge branch 'master' into xsn/lighton-ocr
2 parents 62cc684 + 945501f

35 files changed: +1950 -439 lines changed

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
 - [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
+- [x] [Jamba](https://huggingface.co/ai21labs)
 - [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
 - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
 - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
```

convert_hf_to_gguf.py

Lines changed: 235 additions & 82 deletions
Large diffs are not rendered by default.

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -138,7 +138,7 @@ def fn(_m, input, output):
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )

-config = AutoConfig.from_pretrained(model_path)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

 print("Model type: ", config.model_type)
 print("Vocab size: ", config.vocab_size)
@@ -148,8 +148,8 @@ def fn(_m, input, output):
 print("EOS token id: ", config.eos_token_id)

 print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-config = AutoConfig.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
@@ -171,7 +171,7 @@ def fn(_m, input, output):
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload"
+        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
     )

 for name, module in model.named_modules():
```

ggml/src/ggml-alloc.c

Lines changed: 11 additions & 4 deletions
```diff
@@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
     }

     if (best_fit_block == -1) {
-        // no suitable block found, try the last block (this will grow a chunks size)
+        // no suitable block found, try the last block (this may grow a chunks size)
+        int64_t best_reuse = INT64_MIN;
         for (int c = 0; c < alloc->n_chunks; ++c) {
             struct tallocr_chunk * chunk = alloc->chunks[c];
             if (chunk->n_free_blocks > 0) {
                 struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
                 max_avail = MAX(max_avail, block->size);
-                if (block->size >= size) {
+                int64_t reuse_factor = chunk->max_size - block->offset - size;
+                // reuse_factor < 0 : amount of extra memory that needs to be allocated
+                // reuse_factor = 0 : allocated free space exactly matches tensor size
+                // reuse_factor > 0 : superfluous memory that will remain unused
+                bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
+                bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
+                if (block->size >= size && (better_reuse || better_fit)) {
                     best_fit_chunk = c;
                     best_fit_block = chunk->n_free_blocks - 1;
-                    break;
+                    best_reuse = reuse_factor;
                 }
             }
         }
@@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, addr, tensor);
     size_t cur_max = addr.offset + size;
-    if (cur_max > alloc->max_size[addr.chunk]) {
+    if (cur_max > chunk->max_size) {
         // sort allocated_tensors by chunk/offset
         for (int i = 0; i < 1024; i++) {
             for (int j = i + 1; j < 1024; j++) {
```
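
A note on the hunk above: instead of taking the first tail block that is large enough, the allocator now scores every chunk's tail block and prefers the one that grows its chunk the least, or, when no growth is needed, wastes the least space below the chunk's high-water mark. Below is a minimal, self-contained C++ sketch of that scoring rule, using simplified stand-in types rather than the real ggml allocator structs.

```cpp
// Minimal sketch of the tail-block selection heuristic added in ggml-alloc.c.
// Types and field names are simplified stand-ins, not the actual ggml structs.
#include <cstdint>
#include <cstdio>

struct Chunk {
    int64_t max_size;     // high-water mark of the chunk
    int64_t tail_offset;  // offset of the chunk's last free block
    int64_t tail_size;    // size of the chunk's last free block
};

// Returns the index of the chunk whose tail block should receive a tensor of
// `size` bytes, or -1 if no tail block is large enough.
static int pick_chunk(const Chunk * chunks, int n_chunks, int64_t size) {
    int     best_chunk = -1;
    int64_t best_reuse = INT64_MIN;
    for (int c = 0; c < n_chunks; ++c) {
        const Chunk & ch = chunks[c];
        if (ch.tail_size < size) {
            continue; // tail block too small for this tensor
        }
        // reuse_factor < 0 : the chunk must grow by that many bytes
        // reuse_factor = 0 : the tensor ends exactly at the high-water mark
        // reuse_factor > 0 : bytes below the high-water mark left unused
        int64_t reuse_factor = ch.max_size - ch.tail_offset - size;
        bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;   // grows less (or not at all)
        bool better_fit   = reuse_factor >= 0 && reuse_factor < best_reuse; // wastes less
        if (better_reuse || better_fit) {
            best_chunk = c;
            best_reuse = reuse_factor;
        }
    }
    return best_chunk;
}

int main() {
    Chunk chunks[2] = {
        { /*max_size=*/1024, /*tail_offset=*/900,  /*tail_size=*/512  }, // would have to grow by 388
        { /*max_size=*/2048, /*tail_offset=*/1024, /*tail_size=*/1024 }, // fits below the high-water mark
    };
    printf("chosen chunk: %d\n", pick_chunk(chunks, 2, /*size=*/512)); // -> 1
    return 0;
}
```

The preference order mirrors the diff: a chunk that can place the tensor without growing beats every chunk that must grow, and ties within each group are broken toward less growth or less leftover space.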

ggml/src/ggml-cuda/argsort.cu

Lines changed: 100 additions & 4 deletions
```diff
@@ -1,5 +1,81 @@
 #include "argsort.cuh"

+#ifdef GGML_CUDA_USE_CUB
+#    include <cub/cub.cuh>
+using namespace cub;
+#endif  // GGML_CUDA_USE_CUB
+
+static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+    const int row = blockIdx.y;
+
+    if (col < ncols && row < nrows) {
+        indices[row * ncols + col] = col;
+    }
+}
+
+static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx <= nrows) {
+        offsets[idx] = idx * ncols;
+    }
+}
+
+#ifdef GGML_CUDA_USE_CUB
+static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                                     const float *    x,
+                                     int *            dst,
+                                     const int        ncols,
+                                     const int        nrows,
+                                     ggml_sort_order  order,
+                                     cudaStream_t     stream) {
+    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);
+
+    int *   temp_indices = temp_indices_alloc.get();
+    float * temp_keys    = temp_keys_alloc.get();
+    int *   d_offsets    = offsets_alloc.get();
+
+    static const int block_size = 256;
+    const dim3       grid_size((ncols + block_size - 1) / block_size, nrows);
+    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
+
+    const dim3 offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+
+    cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
+
+    size_t temp_storage_bytes = 0;
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                            temp_indices, dst,                                  // values (indices)
+                                            ncols * nrows, nrows,                               // num items, num segments
+                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,     // all bits
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
+                                                      sizeof(float) * 8, stream);
+    }
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void * d_temp_storage = temp_storage_alloc.get();
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                      0, sizeof(float) * 8, stream);
+    }
+}
+#endif  // GGML_CUDA_USE_CUB
+
+// Bitonic sort implementation
 template<typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {
     T tmp = a;
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
     return n;
 }

-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+static void argsort_f32_i32_cuda_bitonic(const float * x,
+                                         int *           dst,
+                                         const int       ncols,
+                                         const int       nrows,
+                                         ggml_sort_order order,
+                                         cudaStream_t    stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);

@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];

-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+#ifdef GGML_CUDA_USE_CUB
+    const int    ncols_pad      = next_power_of_2(ncols);
+    const size_t shared_mem     = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        ggml_cuda_pool & pool = ctx.pool();
+        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    }
+#else
+    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#endif
 }
```
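
The repeated SortPairs/SortPairsDescending calls in the hunk above follow CUB's usual two-phase protocol: the first call with a null temp-storage pointer only reports the required scratch size, and the second call with real scratch memory performs the sort. A standalone sketch of that pattern for a row-wise argsort, using plain cudaMalloc and fixed toy sizes instead of the ggml pool:

```cuda
// Two-phase cub::DeviceSegmentedRadixSort usage: sort each row of a small
// row-major matrix and write the per-row argsort indices to d_vals_out.
#include <cub/cub.cuh>
#include <cstdio>

int main() {
    const int nrows = 2, ncols = 4, n = nrows * ncols;
    const float h_keys[n] = {3.f, 1.f, 2.f, 0.f,   5.f, 7.f, 4.f, 6.f};
    int h_vals[n], h_offsets[nrows + 1];
    for (int r = 0; r < nrows; ++r) {
        for (int c = 0; c < ncols; ++c) h_vals[r * ncols + c] = c;  // per-row column indices
    }
    for (int i = 0; i <= nrows; ++i) h_offsets[i] = i * ncols;      // segment (row) boundaries

    float *d_keys_in, *d_keys_out; int *d_vals_in, *d_vals_out, *d_offsets;
    cudaMalloc((void **) &d_keys_in,  n * sizeof(float));
    cudaMalloc((void **) &d_keys_out, n * sizeof(float));
    cudaMalloc((void **) &d_vals_in,  n * sizeof(int));
    cudaMalloc((void **) &d_vals_out, n * sizeof(int));
    cudaMalloc((void **) &d_offsets, (nrows + 1) * sizeof(int));
    cudaMemcpy(d_keys_in, h_keys, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_vals_in, h_vals, n * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_offsets, h_offsets, (nrows + 1) * sizeof(int), cudaMemcpyHostToDevice);

    // Phase 1: null temp storage -> CUB only writes the required scratch size.
    void * d_temp = nullptr; size_t temp_bytes = 0;
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                             d_vals_in, d_vals_out, n, nrows,
                                             d_offsets, d_offsets + 1);
    cudaMalloc(&d_temp, temp_bytes);

    // Phase 2: same call with real scratch memory performs the segmented sort.
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                             d_vals_in, d_vals_out, n, nrows,
                                             d_offsets, d_offsets + 1);

    cudaMemcpy(h_vals, d_vals_out, n * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) printf("%d ", h_vals[i]);  // ascending argsort of each row
    printf("\n");
    cudaFree(d_keys_in); cudaFree(d_keys_out); cudaFree(d_vals_in);
    cudaFree(d_vals_out); cudaFree(d_offsets); cudaFree(d_temp);
    return 0;
}
```

The dispatch in ggml_cuda_op_argsort keeps the shared-memory bitonic kernel for short rows and only falls back to the CUB path when the row is too long for shared memory or exceeds 1024 columns.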

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
     const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
     const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

-    if (block_nums.z > 65535) {
+    if (block_nums.z > 65535 || block_nums.y > 65535) {
         int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
         const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
         const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
```
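
Context for the extra `block_nums.y` check: CUDA caps gridDim.y and gridDim.z at 65535 while gridDim.x can go up to roughly 2^31, so either oversized dimension has to trigger the collapsed 1D fallback. A quick way to confirm the limits on a given device:

```cuda
// Print the per-dimension grid limits that motivate the 65535 checks above.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    printf("maxGridSize: x=%d y=%d z=%d\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);
    return 0;
}
```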

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 0 deletions
```diff
@@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context {
         return pool(device);
     }
 };
+
+struct ggml_cuda_mm_fusion_args_host {
+    const ggml_tensor * x_bias = nullptr;
+    const ggml_tensor * gate = nullptr;
+    const ggml_tensor * gate_bias = nullptr;
+    ggml_glu_op glu_op;
+};
+struct ggml_cuda_mm_fusion_args_device {
+    const void * x_bias = nullptr;
+    const void * gate = nullptr;
+    const void * gate_bias = nullptr;
+    ggml_glu_op glu_op;
+};
```
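
These structs describe an optional bias/gate epilogue for a fused mat-mul path, once as host-side tensor views and once as raw device pointers. As a rough illustration only (assuming a SwiGLU-style convention out = silu(x + x_bias) * (gate + gate_bias), which may not be ggml's exact operand order or activation), an elementwise epilogue kernel could look like this:

```cuda
// Hypothetical epilogue sketch, not the actual ggml fused kernel: applies a
// SwiGLU-style gate with optional biases, mirroring the fields of the structs above.
#include <cstdio>
#include <cmath>
#include <cuda_runtime.h>

static __device__ float silu(float v) { return v / (1.0f + expf(-v)); }

static __global__ void glu_epilogue(float * out, const float * x, const float * gate,
                                    const float * x_bias, const float * gate_bias, const int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    const float xv = x[i]    + (x_bias    ? x_bias[i]    : 0.0f);  // bias pointers are optional
    const float gv = gate[i] + (gate_bias ? gate_bias[i] : 0.0f);
    out[i] = silu(xv) * gv;  // assumed SwiGLU convention; other glu_op values would swap the activation
}

int main() {
    const int n = 4;
    float *x, *gate, *out;
    cudaMallocManaged(&x, n * sizeof(float));
    cudaMallocManaged(&gate, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; ++i) { x[i] = 1.0f; gate[i] = (float) i; }

    glu_epilogue<<<1, 32>>>(out, x, gate, /*x_bias=*/nullptr, /*gate_bias=*/nullptr, n);
    cudaDeviceSynchronize();
    for (int i = 0; i < n; ++i) printf("%f ", out[i]);  // silu(1) * i
    printf("\n");
    cudaFree(x); cudaFree(gate); cudaFree(out);
    return 0;
}
```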

ggml/src/ggml-cuda/convert.cuh

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,3 +1,4 @@
+#pragma once
 #include "common.cuh"

 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
```

ggml/src/ggml-cuda/cpy.cu

Lines changed: 69 additions & 11 deletions
```diff
@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }

+template<typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const src_t * x = (const src_t *) cx;
+    dst_t * dst = (dst_t *) cdst;
+
+    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
+}
+
+template<typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+    const char * cx, char * cdst, const int64_t ne,
+    cudaStream_t stream) {
+
+    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne);
+}
+
 template<typename src_t, typename dst_t>
 static void ggml_cpy_flt_cuda(
     const char * cx, char * cdst, const int ne,
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;

-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+
+    if (src0->type == src1->type && contiguous_srcs) {
         GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
 #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
         if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        if (contiguous_srcs) {
+            ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                    ggml_type_name(src0->type), ggml_type_name(src1->type));
```