Skip to content

Commit 90fd992

Browse files
author
bssrdf
committed
minor tweak
1 parent: c36b70b; commit: 90fd992

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

ggml/src/ggml-cuda/cpy.cu

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -306,14 +306,18 @@ static void ggml_cpy_flt_cuda(
306306
// printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
307307
// printf("d %zu, %zu, %zu, %zu, \n", nb10, nb11, nb12, nb13);
308308
// GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
309-
if( nb00 < nb02 && nb02 < nb03) {
309+
if( nb00 < nb02 && nb02 <= nb03 ) {
310+
// printf("a %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
311+
// printf("c %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
310312
dim3 dimGrid( (ne01 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
311313
(ne00 + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
312314
(ne/(ne01*ne00) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
313315
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
314316
cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
315317
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
316318
} else{
319+
// printf("b %zu, %zu, %zu, %zu, \n", ne, ne00, ne01, ne02);
320+
// printf("d %zu, %zu, %zu, %zu, \n", nb00, nb01, nb02, nb03);
317321
std::vector<std::tuple<int, int, int>> v;
318322
v.emplace_back(std::make_tuple(nb00, ne00, 0));
319323
v.emplace_back(std::make_tuple(nb01, ne01, 1));
@@ -535,7 +539,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
535539

536540
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
537541
const bool can_be_transposed = src0->op == GGML_OP_TRANSPOSE && !ggml_is_contiguous(src0) &&
538-
(src0->ne[3] == 1 || (src0->nb[2] < src0->nb[3] && src0->nb[0] < src0->nb[2]));
542+
(src0->ne[3] == 1 || (src0->nb[2] <= src0->nb[3] && src0->nb[0] < src0->nb[2]));
539543

540544
if (src0->type == src1->type && contiguous_srcs) {
541545
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));

0 commit comments

Comments
 (0)