Skip to content

Commit 9fdf8ad

Browse files
committed
add constexpr and static assert
1 parent 618708c commit 9fdf8ad

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ggml/src/ggml-cuda/concat.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
124124
uint64_t nb1,
125125
uint64_t nb2,
126126
uint64_t nb3){
127+
static_assert(dim >= 0 && dim <= 3);
128+
127129
const int64_t i3 = blockIdx.z;
128130
const int64_t i2 = blockIdx.y;
129131
const int64_t i1 = blockIdx.x;
@@ -134,13 +136,13 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
134136
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
135137
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
136138
} else {
137-
if /*constexpr*/ (dim == 0) {
139+
if constexpr (dim == 0) {
138140
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
139-
} else if (dim == 1) {
141+
} else if constexpr (dim == 1) {
140142
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
141-
} else if (dim == 2) {
143+
} else if constexpr (dim == 2) {
142144
x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
143-
} else if (dim == 3) {
145+
} else if constexpr (dim == 3) {
144146
x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
145147
}
146148
}

0 commit comments

Comments
 (0)