|
| 1 | +#ifndef GPUFIT_NATURAL_BSPLINE_ND_CUH_INCLUDED |
| 2 | +#define GPUFIT_NATURAL_BSPLINE_ND_CUH_INCLUDED |
| 3 | + |
| 4 | +#include "bspline_fast_cubic_basis_evaluate_device.cuh" |
| 5 | + |
| 6 | +// Max number of supported dimensions (adjust as needed) |
| 7 | +#define NATURAL_BSPLINE_ND_MAX_DIMS 6 |
| 8 | +#define CUBIC_DEGREE 3 |
| 9 | +#define CUBIC_BASIS_SIZE 4 |
| 10 | + |
| 11 | +/* |
| 12 | +* Description of the calculate_natural_bspline_nd function |
| 13 | +* ======================================================== |
| 14 | +* |
| 15 | +* parameters: [amp, center_0, ..., center_{D-1}, offset] |
| 16 | +* n_fits: number of fits (for Gpufit batch) |
| 17 | +* n_points: number of points per fit |
| 18 | +* value: output model values |
| 19 | +* derivative: output derivatives |
| 20 | +* point_index: index of the current point (for each axis, see below) |
| 21 | +* fit_index: index of the current fit |
| 22 | +* chunk_index: chunk index |
| 23 | +* user_info: buffer holding all ND spline metadata in the following order: |
| 24 | +* user_info[0] = D (number of dims) |
| 25 | +* user_info[1 ... D] = number of data points per axis |
| 26 | +* user_info[1+D ... 1+2D] = number of control points per axis |
| 27 | +* user_info[1+2D ... 1+3D] = number of knots per axis |
| 28 | +* user_info[... next] = flattened knot vectors (all axes, all knots per axis) |
| 29 | +* user_info[... next] = flattened control point strides (for flattened coeff tensor) |
| 30 | +* user_info[... next] = coefficients (flattened, length = product of control points per axis) |
| 31 | +* |
| 32 | +* For best performance, precompute and pack all these arrays on the host using your Gpuspline code. |
| 33 | +*/ |
| 34 | + |
| 35 | +__device__ void calculate_natural_bspline_nd( |
| 36 | + REAL const * parameters, // [amp, center_0..center_{D-1}, offset] |
| 37 | + int const n_fits, |
| 38 | + int const n_points, |
| 39 | + REAL * value, |
| 40 | + REAL * derivative, |
| 41 | + int const point_index, |
| 42 | + int const fit_index, |
| 43 | + int const chunk_index, |
| 44 | + char * user_info, |
| 45 | + std::size_t const user_info_size) |
| 46 | +{ |
| 47 | + // --- Unpack user_info --- |
| 48 | + REAL const * ui = (REAL*)user_info; |
| 49 | + |
| 50 | + int D = static_cast<int>(ui[0]); |
| 51 | + int d; |
| 52 | + int const * data_points = (int const *)(ui + 1); // [D] |
| 53 | + int const * control_points = data_points + D; // [D] |
| 54 | + int const * num_knots = control_points + D; // [D] |
| 55 | + |
| 56 | + int offset_knots = 1 + 3 * D; |
| 57 | + int offset_strides = offset_knots; |
| 58 | + for (d = 0; d < D; ++d) offset_strides += num_knots[d]; |
| 59 | + int offset_coeff = offset_strides + D; |
| 60 | + |
| 61 | + // --- Pointers to ND arrays --- |
| 62 | + REAL const * knots[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 63 | + int n_ctrl[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 64 | + int n_stride[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 65 | + int n_span[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 66 | + |
| 67 | + int acc = offset_knots; |
| 68 | + for (d = 0; d < D; ++d) |
| 69 | + { |
| 70 | + knots[d] = ui + acc; |
| 71 | + acc += num_knots[d]; |
| 72 | + n_ctrl[d] = control_points[d]; |
| 73 | + n_stride[d] = static_cast<int>(ui[offset_strides + d]); |
| 74 | + } |
| 75 | + REAL const * coeff = ui + offset_coeff; |
| 76 | + |
| 77 | + // --- Map point_index to ND coordinates (x[0], ..., x[D-1]) --- |
| 78 | + // For image or ND array, this is usually unravel_index(point_index, data_points) |
| 79 | + int coords[NATURAL_BSPLINE_ND_MAX_DIMS] = { 0 }; |
| 80 | + int idx = point_index; |
| 81 | + for (d = 0; d < D; ++d) |
| 82 | + { |
| 83 | + coords[d] = idx % data_points[d]; |
| 84 | + idx /= data_points[d]; |
| 85 | + } |
| 86 | + |
| 87 | + // --- Unpack model parameters --- |
| 88 | + REAL amp = parameters[0]; |
| 89 | + REAL center[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 90 | + for (d = 0; d < D; ++d) |
| 91 | + center[d] = parameters[1 + d]; |
| 92 | + REAL offset = parameters[1 + D]; |
| 93 | + |
| 94 | + // --- Compute shifted input coords: pt[d] = coords[d] - center[d] --- |
| 95 | + REAL pt[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 96 | + for (d = 0; d < D; ++d) |
| 97 | + pt[d] = static_cast<REAL>(coords[d]) - center[d]; |
| 98 | + |
| 99 | + // --- For each axis: find knot span, evaluate basis and derivative --- |
| 100 | + int k = CUBIC_DEGREE; |
| 101 | + int span[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 102 | + REAL B[NATURAL_BSPLINE_ND_MAX_DIMS][CUBIC_BASIS_SIZE]; |
| 103 | + REAL dB[NATURAL_BSPLINE_ND_MAX_DIMS][CUBIC_BASIS_SIZE]; |
| 104 | + |
| 105 | + for (d = 0; d < D; ++d) |
| 106 | + { |
| 107 | + int N = data_points[d]; |
| 108 | + int M = n_ctrl[d]; |
| 109 | + // Find knot span |
| 110 | + REAL xq = pt[d]; |
| 111 | + if (xq <= 0.0) |
| 112 | + span[d] = k; |
| 113 | + else if (xq >= REAL(N - 1)) |
| 114 | + span[d] = M - 1; |
| 115 | + else |
| 116 | + span[d] = static_cast<int>(xq) + k; |
| 117 | + // Basis |
| 118 | + evaluate_fast_cubic_basis_device(xq, span[d], knots[d], M, B[d]); |
| 119 | + evaluate_fast_cubic_basis_derivative_device(xq, span[d], knots[d], M, dB[d]); |
| 120 | + } |
| 121 | + |
| 122 | + // --- Tensor product sum --- |
| 123 | + // Setup for up to 6D (unroll for speed, expand if needed) |
| 124 | + REAL spline_val = 0; |
| 125 | + REAL spline_dx[NATURAL_BSPLINE_ND_MAX_DIMS] = {0}; |
| 126 | + int stride[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 127 | + int offset[NATURAL_BSPLINE_ND_MAX_DIMS]; |
| 128 | + |
| 129 | + for (d = 0; d < D; ++d) |
| 130 | + { |
| 131 | + stride[d] = n_stride[d]; |
| 132 | + offset[d] = span[d] - k; |
| 133 | + } |
| 134 | + |
| 135 | + // Only support up to 6D for hardcoded loops (can extend with templates if needed) |
| 136 | + if (D == 1) |
| 137 | + { |
| 138 | + for (int i0 = 0; i0 < 4; ++i0) |
| 139 | + { |
| 140 | + int idx0 = (offset[0] + i0) * stride[0]; |
| 141 | + REAL c = coeff[idx0]; |
| 142 | + REAL w = B[0][i0]; |
| 143 | + spline_val += w * c; |
| 144 | + spline_dx[0] += dB[0][i0] * c; |
| 145 | + } |
| 146 | + } |
| 147 | + else if (D == 2) |
| 148 | + { |
| 149 | + for (int i0 = 0; i0 < 4; ++i0) |
| 150 | + for (int i1 = 0; i1 < 4; ++i1) |
| 151 | + { |
| 152 | + int idx = (offset[0] + i0) * stride[0] + (offset[1] + i1) * stride[1]; |
| 153 | + REAL c = coeff[idx]; |
| 154 | + REAL w = B[0][i0] * B[1][i1]; |
| 155 | + spline_val += w * c; |
| 156 | + spline_dx[0] += dB[0][i0] * B[1][i1] * c; |
| 157 | + spline_dx[1] += B[0][i0] * dB[1][i1] * c; |
| 158 | + } |
| 159 | + } |
| 160 | + else if (D == 3) |
| 161 | + { |
| 162 | + for (int i0 = 0; i0 < 4; ++i0) |
| 163 | + for (int i1 = 0; i1 < 4; ++i1) |
| 164 | + for (int i2 = 0; i2 < 4; ++i2) |
| 165 | + { |
| 166 | + int idx = (offset[0] + i0) * stride[0] + |
| 167 | + (offset[1] + i1) * stride[1] + |
| 168 | + (offset[2] + i2) * stride[2]; |
| 169 | + REAL c = coeff[idx]; |
| 170 | + REAL w = B[0][i0] * B[1][i1] * B[2][i2]; |
| 171 | + spline_val += w * c; |
| 172 | + spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * c; |
| 173 | + spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * c; |
| 174 | + spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * c; |
| 175 | + } |
| 176 | + } |
| 177 | + else if (D == 4) |
| 178 | + { |
| 179 | + for (int i0 = 0; i0 < 4; ++i0) |
| 180 | + for (int i1 = 0; i1 < 4; ++i1) |
| 181 | + for (int i2 = 0; i2 < 4; ++i2) |
| 182 | + for (int i3 = 0; i3 < 4; ++i3) |
| 183 | + { |
| 184 | + int idx = (offset[0] + i0) * stride[0] + |
| 185 | + (offset[1] + i1) * stride[1] + |
| 186 | + (offset[2] + i2) * stride[2] + |
| 187 | + (offset[3] + i3) * stride[3]; |
| 188 | + REAL c = coeff[idx]; |
| 189 | + REAL w = B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3]; |
| 190 | + spline_val += w * c; |
| 191 | + spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * c; |
| 192 | + spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * B[3][i3] * c; |
| 193 | + spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * B[3][i3] * c; |
| 194 | + spline_dx[3] += B[0][i0] * B[1][i1] * B[2][i2] * dB[3][i3] * c; |
| 195 | + } |
| 196 | + } |
| 197 | + else if (D == 5) |
| 198 | + { |
| 199 | + for (int i0 = 0; i0 < 4; ++i0) |
| 200 | + for (int i1 = 0; i1 < 4; ++i1) |
| 201 | + for (int i2 = 0; i2 < 4; ++i2) |
| 202 | + for (int i3 = 0; i3 < 4; ++i3) |
| 203 | + for (int i4 = 0; i4 < 4; ++i4) |
| 204 | + { |
| 205 | + int idx = (offset[0] + i0) * stride[0] + |
| 206 | + (offset[1] + i1) * stride[1] + |
| 207 | + (offset[2] + i2) * stride[2] + |
| 208 | + (offset[3] + i3) * stride[3] + |
| 209 | + (offset[4] + i4) * stride[4]; |
| 210 | + REAL c = coeff[idx]; |
| 211 | + REAL w = B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4]; |
| 212 | + spline_val += w * c; |
| 213 | + spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * c; |
| 214 | + spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * c; |
| 215 | + spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * B[3][i3] * B[4][i4] * c; |
| 216 | + spline_dx[3] += B[0][i0] * B[1][i1] * B[2][i2] * dB[3][i3] * B[4][i4] * c; |
| 217 | + spline_dx[4] += B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * dB[4][i4] * c; |
| 218 | + } |
| 219 | + } |
| 220 | + else if (D == 6) |
| 221 | + { |
| 222 | + for (int i0 = 0; i0 < 4; ++i0) |
| 223 | + for (int i1 = 0; i1 < 4; ++i1) |
| 224 | + for (int i2 = 0; i2 < 4; ++i2) |
| 225 | + for (int i3 = 0; i3 < 4; ++i3) |
| 226 | + for (int i4 = 0; i4 < 4; ++i4) |
| 227 | + for (int i5 = 0; i5 < 4; ++i5) |
| 228 | + { |
| 229 | + int idx = (offset[0] + i0) * stride[0] + |
| 230 | + (offset[1] + i1) * stride[1] + |
| 231 | + (offset[2] + i2) * stride[2] + |
| 232 | + (offset[3] + i3) * stride[3] + |
| 233 | + (offset[4] + i4) * stride[4] + |
| 234 | + (offset[5] + i5) * stride[5]; |
| 235 | + REAL c = coeff[idx]; |
| 236 | + REAL w = B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * B[5][i5]; |
| 237 | + spline_val += w * c; |
| 238 | + spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * B[5][i5] * c; |
| 239 | + spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * B[5][i5] * c; |
| 240 | + spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * B[3][i3] * B[4][i4] * B[5][i5] * c; |
| 241 | + spline_dx[3] += B[0][i0] * B[1][i1] * B[2][i2] * dB[3][i3] * B[4][i4] * B[5][i5] * c; |
| 242 | + spline_dx[4] += B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * dB[4][i4] * B[5][i5] * c; |
| 243 | + spline_dx[5] += B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * dB[5][i5] * c; |
| 244 | + } |
| 245 | + } |
| 246 | + |
| 247 | + // --- Output model value --- |
| 248 | + value[point_index] = amp * spline_val + offset; |
| 249 | + |
| 250 | + // --- Output derivatives --- |
| 251 | + REAL * der = derivative + point_index; |
| 252 | + der[0 * n_points] = spline_val; // d/d(amp) |
| 253 | + for (d = 0; d < D; ++d) |
| 254 | + der[(1 + d) * n_points] = -amp * spline_dx[d]; // d/d(center_d) |
| 255 | + der[(1 + D) * n_points] = 1; // d/d(offset) |
| 256 | +} |
| 257 | + |
| 258 | +#endif |
0 commit comments