Skip to content

Commit f8a5b04

Browse files
A3shTnT and slaren
authored
Use a lambda to avoid code duplication
Co-authored-by: Diego Devesa <[email protected]>
1 parent e4189e3 commit f8a5b04

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

ggml/src/ggml-cuda/concat.cu

Lines changed: 16 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -188,37 +188,33 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
188188
}
189189
} else {
190190
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
191+
auto launch_kernel = [&](auto dim) {
192+
concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
193+
(const char *) src0->data, (const char *) src1->data, (char *) dst->data,
194+
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
195+
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
196+
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
197+
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
198+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
199+
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
200+
};
191201
switch (dim) {
192202
case 0:
193-
concat_f32_non_cont<0><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
194-
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
195-
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
196-
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
197-
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
203+
launch_kernel(std::integral_constant<int, 0>{});
198204
break;
199205
case 1:
200-
concat_f32_non_cont<1><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
201-
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
202-
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
203-
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
204-
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
206+
launch_kernel(std::integral_constant<int, 1>{});
205207
break;
206208
case 2:
207-
concat_f32_non_cont<2><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
208-
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
209-
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
210-
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
211-
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
209+
launch_kernel(std::integral_constant<int, 2>{});
212210
break;
213211
case 3:
214-
concat_f32_non_cont<3><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
215-
(const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1],
216-
src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
217-
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
218-
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
212+
launch_kernel(std::integral_constant<int, 3>{});
219213
break;
220214
default:
215+
GGML_ABORT("Invalid dim: %d", dim);
221216
break;
222217
}
223218
}
219+
}
224220
}

0 commit comments

Comments
 (0)