
Commit 09c7c50

Phylliida, 0cc4m, am17an, and ggerganov authored
ggml : add circular tiling support to pad, for Vulkan, CUDA, and CPU (used for making seamless textures) (ggml-org#16985)
* Feat: Added vulkan circular tiling support
* Feat: Added cpu circular
* Feat: Added cuda kernels
* Added tests
* Added tests
* Removed non-pad operations
* Removed unneeded changes
* removed backend non pad tests
* Update test-backend-ops.cpp
* Fixed comment on pad test
* removed trailing whitespace
* Removed unneeded test in test-backend-ops
* Removed removed test from calls
* Update ggml/src/ggml-vulkan/vulkan-shaders/pad.comp
  Co-authored-by: Ruben Ortlam <[email protected]>
* Fixed alignment
* Formatting
  Co-authored-by: Aman Gupta <[email protected]>
* Format pad
* Format
* Clang format
* format
* format
* don't change so much stuff
* clang format and update to bool
* fix duplicates
* don't need to fix the padding
* make circular bool
* duplicate again
* rename vulkan to wrap around
* Don't need indent
* moved to const expr
* removed unneeded extra line break
* More readable method calls
* Minor wording changes
* Added final newline
* Update ggml/include/ggml.h
  Co-authored-by: Georgi Gerganov <[email protected]>
* Update ggml/include/ggml.h
  Co-authored-by: Georgi Gerganov <[email protected]>
* Added circular pad ext tests
* Gate non circular pad devices
* Cleaned gating of non-circular pad devices

---------

Co-authored-by: Phylliida <[email protected]>
Co-authored-by: Ruben Ortlam <[email protected]>
Co-authored-by: Aman Gupta <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f334b79 commit 09c7c50

File tree

11 files changed: +213 −59 lines

ggml/include/ggml.h

Lines changed: 22 additions & 0 deletions

@@ -2196,6 +2196,15 @@ extern "C" {
             int                   p2,
             int                   p3);

+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   p0,
+            int                   p1,
+            int                   p2,
+            int                   p3);
+
     GGML_API struct ggml_tensor * ggml_pad_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

@@ -2209,6 +2218,19 @@ extern "C" {
             int                   rp3
             );

+    // pad each dimension with values on the other side of the torus (looping around)
+    GGML_API struct ggml_tensor * ggml_pad_ext_circular(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   lp0,
+            int                   rp0,
+            int                   lp1,
+            int                   rp1,
+            int                   lp2,
+            int                   rp2,
+            int                   lp3,
+            int                   rp3);
+
     // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
     GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
             struct ggml_context * ctx,
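For context, a minimal sketch of how the two new entry points might be used to build a seamlessly tiling texture. The tensor shape, padding amounts, and context size below are illustrative only (not from the PR), and graph building/compute are omitted:

#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,  // arbitrary scratch size for this sketch
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // a hypothetical 64x64 single-channel texture
    struct ggml_tensor * tex = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);

    // grow to 72x72: the 8 extra texels per axis are copied from the opposite
    // edge of the torus instead of zero-filled, so the result tiles seamlessly
    struct ggml_tensor * wrapped = ggml_pad_circular(ctx, tex, 8, 8, 0, 0);

    // asymmetric variant: 4 texels before and 8 after each of the first two dims
    struct ggml_tensor * wrapped2 = ggml_pad_ext_circular(ctx, tex, 4, 8, 4, 8, 0, 0, 0, 0);

    (void) wrapped; (void) wrapped2;
    ggml_free(ctx);
    return 0;
}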

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 2 additions & 0 deletions

@@ -2551,6 +2551,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_PAD:
+            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
+            return ggml_get_op_params_i32(op, 8) == 0;
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
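This gate (and the identical ones in the Metal, OpenCL, and SYCL backends below) relies on how the pad operator lays out its parameters. The layout here is inferred from this diff, and the helper name is hypothetical:

// GGML_OP_PAD op_params layout as implied by this PR:
//   i32[0..7] : lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3 (left/right padding per dim)
//   i32[8]    : circular flag (0 = zero padding, nonzero = wrap around)

// hypothetical helper: a backend without circular kernels accepts only flag == 0
static bool backend_pad_supported(const struct ggml_tensor * op) {
    return ggml_get_op_params_i32(op, 8) == 0;
}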

ggml/src/ggml-cpu/ops.cpp

Lines changed: 38 additions & 11 deletions

@@ -6554,8 +6554,13 @@ static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params
     ggml_compute_forward_mul_mat(params, &dst);
 }

+static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
+    return (coord + size) % size; // adding size avoids negative number weirdness
+}
+
 // ggml_compute_forward_conv_2d

+
 static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
                                               const ggml_tensor * kernel, // [KW, KH, IC, OC]
                                               const ggml_tensor * src,    // [W, H, C, N]

@@ -7591,6 +7596,7 @@ void ggml_compute_forward_upscale(

 // ggml_compute_forward_pad

+template<bool circular_t>
 static void ggml_compute_forward_pad_f32(
     const ggml_compute_params * params,
           ggml_tensor * dst) {

@@ -7615,40 +7621,61 @@ static void ggml_compute_forward_pad_f32(
     const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
     const int32_t rp3 = ggml_get_op_params_i32(dst, 7);

-
     // TODO: optimize

     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
                 for (int64_t i3 = 0; i3 < ne3; ++i3) {
-                    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-                    if ((i0 >= lp0 && i0 < ne0 - rp0) \
-                        && (i1 >= lp1 && i1 < ne1 - rp1) \
-                        && (i2 >= lp2 && i2 < ne2 - rp2) \
-                        && (i3 >= lp3 && i3 < ne3 - rp3)) {
-                        const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                    // circular means wrap around on a torus, so x and y loop around
+                    if constexpr (circular_t) {
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
+                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
+                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
+                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
+
+                        const int64_t src_idx =
+                            src_i3*nb03 +
+                            src_i2*nb02 +
+                            src_i1*nb01 +
+                            src_i0*nb00;
+
                         const float * src_ptr = (const float *)((char *) src0->data + src_idx);
                         dst_ptr[dst_idx] = *src_ptr;
                     } else {
-                        dst_ptr[dst_idx] = 0;
+                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+                        if ((i0 >= lp0 && i0 < ne0 - rp0) \
+                            && (i1 >= lp1 && i1 < ne1 - rp1) \
+                            && (i2 >= lp2 && i2 < ne2 - rp2) \
+                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
+                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
+                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
+                            dst_ptr[dst_idx] = *src_ptr;
+                        } else {
+                            dst_ptr[dst_idx] = 0;
+                        }
                     }
                 }
             }
         }
     }
 }

+
 void ggml_compute_forward_pad(
     const ggml_compute_params * params,
           ggml_tensor * dst) {
-
     const ggml_tensor * src0 = dst->src[0];
-
+    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_pad_f32(params, dst);
+                if (circular) {
+                    ggml_compute_forward_pad_f32<true>(params, dst);
+                } else {
+                    ggml_compute_forward_pad_f32<false>(params, dst);
+                }
             } break;
         default:
             {
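Why the `+ size` in `ggml_wrap_around` matters: C's `%` truncates toward zero, so a negative coordinate such as `-1 % 4` evaluates to `-1` and would index out of bounds. A standalone sketch of the intended torus behavior, with illustrative values:

#include <stdint.h>
#include <stdio.h>

static inline int64_t wrap_around(int64_t coord, int64_t size) {
    return (coord + size) % size;  // valid for any coord >= -size
}

int main(void) {
    printf("%lld\n", (long long) wrap_around(-1, 4)); // 3: left padding reads the right edge
    printf("%lld\n", (long long) wrap_around( 5, 4)); // 1: right padding reads the left edge
    printf("%lld\n", (long long) wrap_around( 2, 4)); // 2: interior coordinates pass through
    return 0;
}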

ggml/src/ggml-cuda/pad.cu

Lines changed: 64 additions & 33 deletions

@@ -1,9 +1,17 @@
 #include "pad.cuh"

+#include <stdint.h>
+
+__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
+    // + size ensures negatives are handled properly
+    return (coord + size) % size;
+}
+
 static __global__ void pad_f32(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3) {
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular) {
     // blockIdx.z: i3*ne2+i2
     // blockIdx.y: i1
     // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE

@@ -12,61 +20,84 @@ static __global__ void pad_f32(const float * src, float * dst,
     int i1 = blockIdx.y;
     int i2 = blockIdx.z % ne2;
     int i3 = blockIdx.z / ne2;
+
     if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
         return;
     }

-    // operation
-    const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
-    if ((i0 >= lp0 && i0 < ne0 - rp0) &&
-        (i1 >= lp1 && i1 < ne1 - rp1) &&
-        (i2 >= lp2 && i2 < ne2 - rp2) &&
-        (i3 >= lp3 && i3 < ne3 - rp3)) {
-        const int64_t i00 = i0 - lp0;
-        const int64_t i01 = i1 - lp1;
-        const int64_t i02 = i2 - lp2;
-        const int64_t i03 = i3 - lp3;
-        const int64_t ne02 = ne2 - lp2 - rp2;
-        const int64_t ne01 = ne1 - lp1 - rp1;
+    const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
+
+    if (!circular) {
+        if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
+            (i3 >= lp3 && i3 < ne3 - rp3)) {
+            const int64_t i00 = i0 - lp0;
+            const int64_t i01 = i1 - lp1;
+            const int64_t i02 = i2 - lp2;
+            const int64_t i03 = i3 - lp3;
+            const int64_t ne02 = ne2 - lp2 - rp2;
+            const int64_t ne01 = ne1 - lp1 - rp1;
+            const int64_t ne00 = ne0 - lp0 - rp0;
+
+            const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
+
+            dst[dst_idx] = src[src_idx];
+        } else {
+            dst[dst_idx] = 0.0f;
+        }
+    }
+    // circular means on a torus, so x and y wrap around
+    else {
         const int64_t ne00 = ne0 - lp0 - rp0;
+        const int64_t ne01 = ne1 - lp1 - rp1;
+        const int64_t ne02 = ne2 - lp2 - rp2;
+        const int64_t ne03 = ne3 - lp3 - rp3;
+
+        const int64_t i00 = wrap_around(i0 - lp0, ne00);
+        const int64_t i01 = wrap_around(i1 - lp1, ne01);
+        const int64_t i02 = wrap_around(i2 - lp2, ne02);
+        const int64_t i03 = wrap_around(i3 - lp3, ne03);

-    const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
+        const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;

         dst[dst_idx] = src[src_idx];
-    } else {
-        dst[dst_idx] = 0.0f;
     }
 }

+
 static void pad_f32_cuda(const float * src, float * dst,
     const int lp0, const int rp0, const int lp1, const int rp1,
     const int lp2, const int rp2, const int lp3, const int rp3,
-    const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
-    int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
-    dim3 gridDim(num_blocks, ne1, ne2*ne3);
-    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
+    const int ne0, const int ne1, const int ne2, const int ne3,
+    const bool circular, cudaStream_t stream) {
+    int  num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+    dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+    pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
+        lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+        ne0, ne1, ne2, ne3, circular);
 }

 void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
+    const ggml_tensor * src0   = dst->src[0];
+    const float *       src0_d = (const float *) src0->data;
+    float *             dst_d  = (float *) dst->data;
+    cudaStream_t        stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));

-    const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
-    const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
-    const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
+    const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
+    const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
+    const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
+    const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
+    const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
+    const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
+    const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
+    const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
+    const int32_t circular = ((const int32_t *) (dst->op_params))[8];

     pad_f32_cuda(src0_d, dst_d,
         lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
-        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
+        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+        (bool) circular, stream);
 }
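For orientation, the launch geometry maps one thread to one destination element, with the two slowest dimensions folded into the z axis of the grid. A small host-side sketch of the arithmetic, using hypothetical output extents and an assumed block size of 256 (the real CUDA_PAD_BLOCK_SIZE constant lives in pad.cuh):

#include <stdio.h>

int main(void) {
    const int ne0 = 70, ne1 = 70, ne2 = 3, ne3 = 1; // hypothetical padded output extents
    const int block_size = 256;                     // assumed CUDA_PAD_BLOCK_SIZE

    const int num_blocks = (ne0 + block_size - 1) / block_size;  // ceil(70/256) = 1
    printf("grid = (%d, %d, %d)\n", num_blocks, ne1, ne2 * ne3); // grid = (1, 70, 3)

    // the kernel unfolds blockIdx.z back into the two slow indices:
    for (int bz = 0; bz < ne2 * ne3; ++bz) {
        printf("blockIdx.z = %d -> i2 = %d, i3 = %d\n", bz, bz % ne2, bz / ne2);
    }
    return 0;
}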

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 5 additions & 0 deletions

@@ -1037,6 +1037,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
+
             return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
                    (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
         case GGML_OP_PAD_REFLECT_1D:

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 4 additions & 0 deletions

@@ -3083,6 +3083,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_REPEAT:
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
         case GGML_OP_PAD:
+            // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         case GGML_OP_UPSCALE: {
             ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions

@@ -4613,6 +4613,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ACC:
             return true;
         case GGML_OP_PAD:
+            // TODO: add circular padding support for sycl, see https://github.com/ggml-org/llama.cpp/pull/16985
+            if (ggml_get_op_params_i32(op, 8) != 0) {
+                return false;
+            }
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions

@@ -1050,6 +1050,7 @@ struct vk_op_pad_push_constants {
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
     uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
     uint32_t misalign_offsets;
+    uint32_t circular;

     uint32_t lp0; uint32_t rp0;
     uint32_t lp1; uint32_t rp1;

@@ -1092,6 +1093,7 @@ static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor
     p.rp2 = dst->op_params[5];
     p.lp3 = dst->op_params[6];
     p.rp3 = dst->op_params[7];
+    p.circular = dst->op_params[8];

     return p; // fastdiv values and offsets are initialized later in ggml_vk_op
 }

ggml/src/ggml-vulkan/vulkan-shaders/pad.comp

Lines changed: 20 additions & 5 deletions

@@ -8,6 +8,7 @@ layout (push_constant) uniform parameter
     uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
     uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
     uint misalign_offsets;
+    uint circular;

     uint lp0; uint rp0;
     uint lp1; uint rp1;

@@ -18,6 +19,10 @@ layout (push_constant) uniform parameter
 uint get_aoffset() { return p.misalign_offsets >> 16; }
 uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

+uint wrap_around(int coord, uint size) {
+    return (uint(coord + int(size))) % size; // add size to avoid issues with negative
+}
+
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

@@ -40,10 +45,20 @@ void main() {
     const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
     const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;

-    const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
-                         i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
-                         i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
-                         i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
+    if (p.circular != 0u) {
+        const uint ci0 = wrap_around(int(i0) - int(p.lp0), p.ne00);
+        const uint ci1 = wrap_around(int(i1) - int(p.lp1), p.ne01);
+        const uint ci2 = wrap_around(int(i2) - int(p.lp2), p.ne02);
+        const uint ci3 = wrap_around(int(i3) - int(p.lp3), p.ne03);
+        const uint circular_src_idx = ci3*p.nb03 + ci2*p.nb02 + ci1*p.nb01 + ci0*p.nb00;
+        data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + circular_src_idx]);
+    } else {
+        const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
+                             i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
+                             i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
+                             i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
+        data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
+    }
+

-    data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
 }
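Note that the new circular field must occupy the same slot in the host-side vk_op_pad_push_constants struct (previous file) and in the shader's push-constant block above; if the two ever drift apart, every field after it shifts and the shader reads garbage. A hedged sketch of a host-side compile-time guard, not part of the PR, using an illustrative mirror of the struct prefix:

#include <stddef.h>
#include <stdint.h>

// illustrative mirror of the start of vk_op_pad_push_constants
struct pad_pc_prefix {
    uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;
    uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;
    uint32_t misalign_offsets;
    uint32_t circular;
};

// an all-uint32_t struct has no internal padding, so field order fixes the
// offsets: circular must sit at 17 * 4 = 68 bytes, right where the shader's
// "uint circular;" follows "uint misalign_offsets;"
_Static_assert(offsetof(struct pad_pc_prefix, circular) == 17 * sizeof(uint32_t),
               "circular must directly follow misalign_offsets, as in pad.comp");

int main(void) { return 0; }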
