Skip to content

Commit b9b39e2

Browse files
committed
updates to B-spline models, added N-D fitting
1 parent 154965f commit b9b39e2

File tree

3 files changed

+280
-19
lines changed

3 files changed

+280
-19
lines changed

Gpufit/models/natural_bspline_1d.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
* chunk_index: chunk number
3737
*
3838
* user_info: passed-in buffer with spline meta-data:
39-
* user_info[0] = num_control_points
39+
* user_info[0] = num_control_points
4040
* user_info[1...num_control_points+4] = knot vector (float)
4141
* user_info[1+num_control_points+4 ...] = coefficients
4242
*
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
#ifndef GPUFIT_NATURAL_BSPLINE_ND_CUH_INCLUDED
2+
#define GPUFIT_NATURAL_BSPLINE_ND_CUH_INCLUDED
3+
4+
#include "bspline_fast_cubic_basis_evaluate_device.cuh"
5+
6+
// Max number of supported dimensions (adjust as needed)
7+
#define NATURAL_BSPLINE_ND_MAX_DIMS 6
8+
#define CUBIC_DEGREE 3
9+
#define CUBIC_BASIS_SIZE 4
10+
11+
/*
12+
* Description of the calculate_natural_bspline_nd function
13+
* ========================================================
14+
*
15+
* parameters: [amp, center_0, ..., center_{D-1}, offset]
16+
* n_fits: number of fits (for Gpufit batch)
17+
* n_points: number of points per fit
18+
* value: output model values
19+
* derivative: output derivatives
20+
* point_index: index of the current point (for each axis, see below)
21+
* fit_index: index of the current fit
22+
* chunk_index: chunk index
23+
* user_info: buffer holding all ND spline metadata in the following order:
24+
* user_info[0] = D (number of dims)
25+
* user_info[1 ... D] = number of data points per axis
26+
* user_info[1+D ... 1+2D] = number of control points per axis
27+
* user_info[1+2D ... 1+3D] = number of knots per axis
28+
* user_info[... next] = flattened knot vectors (all axes, all knots per axis)
29+
* user_info[... next] = flattened control point strides (for flattened coeff tensor)
30+
* user_info[... next] = coefficients (flattened, length = product of control points per axis)
31+
*
32+
* For best performance, precompute and pack all these arrays on the host using your Gpuspline code.
33+
*/
34+
35+
__device__ void calculate_natural_bspline_nd(
36+
REAL const * parameters, // [amp, center_0..center_{D-1}, offset]
37+
int const n_fits,
38+
int const n_points,
39+
REAL * value,
40+
REAL * derivative,
41+
int const point_index,
42+
int const fit_index,
43+
int const chunk_index,
44+
char * user_info,
45+
std::size_t const user_info_size)
46+
{
47+
// --- Unpack user_info ---
48+
REAL const * ui = (REAL*)user_info;
49+
50+
int D = static_cast<int>(ui[0]);
51+
int d;
52+
int const * data_points = (int const *)(ui + 1); // [D]
53+
int const * control_points = data_points + D; // [D]
54+
int const * num_knots = control_points + D; // [D]
55+
56+
int offset_knots = 1 + 3 * D;
57+
int offset_strides = offset_knots;
58+
for (d = 0; d < D; ++d) offset_strides += num_knots[d];
59+
int offset_coeff = offset_strides + D;
60+
61+
// --- Pointers to ND arrays ---
62+
REAL const * knots[NATURAL_BSPLINE_ND_MAX_DIMS];
63+
int n_ctrl[NATURAL_BSPLINE_ND_MAX_DIMS];
64+
int n_stride[NATURAL_BSPLINE_ND_MAX_DIMS];
65+
int n_span[NATURAL_BSPLINE_ND_MAX_DIMS];
66+
67+
int acc = offset_knots;
68+
for (d = 0; d < D; ++d)
69+
{
70+
knots[d] = ui + acc;
71+
acc += num_knots[d];
72+
n_ctrl[d] = control_points[d];
73+
n_stride[d] = static_cast<int>(ui[offset_strides + d]);
74+
}
75+
REAL const * coeff = ui + offset_coeff;
76+
77+
// --- Map point_index to ND coordinates (x[0], ..., x[D-1]) ---
78+
// For image or ND array, this is usually unravel_index(point_index, data_points)
79+
int coords[NATURAL_BSPLINE_ND_MAX_DIMS] = { 0 };
80+
int idx = point_index;
81+
for (d = 0; d < D; ++d)
82+
{
83+
coords[d] = idx % data_points[d];
84+
idx /= data_points[d];
85+
}
86+
87+
// --- Unpack model parameters ---
88+
REAL amp = parameters[0];
89+
REAL center[NATURAL_BSPLINE_ND_MAX_DIMS];
90+
for (d = 0; d < D; ++d)
91+
center[d] = parameters[1 + d];
92+
REAL offset = parameters[1 + D];
93+
94+
// --- Compute shifted input coords: pt[d] = coords[d] - center[d] ---
95+
REAL pt[NATURAL_BSPLINE_ND_MAX_DIMS];
96+
for (d = 0; d < D; ++d)
97+
pt[d] = static_cast<REAL>(coords[d]) - center[d];
98+
99+
// --- For each axis: find knot span, evaluate basis and derivative ---
100+
int k = CUBIC_DEGREE;
101+
int span[NATURAL_BSPLINE_ND_MAX_DIMS];
102+
REAL B[NATURAL_BSPLINE_ND_MAX_DIMS][CUBIC_BASIS_SIZE];
103+
REAL dB[NATURAL_BSPLINE_ND_MAX_DIMS][CUBIC_BASIS_SIZE];
104+
105+
for (d = 0; d < D; ++d)
106+
{
107+
int N = data_points[d];
108+
int M = n_ctrl[d];
109+
// Find knot span
110+
REAL xq = pt[d];
111+
if (xq <= 0.0)
112+
span[d] = k;
113+
else if (xq >= REAL(N - 1))
114+
span[d] = M - 1;
115+
else
116+
span[d] = static_cast<int>(xq) + k;
117+
// Basis
118+
evaluate_fast_cubic_basis_device(xq, span[d], knots[d], M, B[d]);
119+
evaluate_fast_cubic_basis_derivative_device(xq, span[d], knots[d], M, dB[d]);
120+
}
121+
122+
// --- Tensor product sum ---
123+
// Setup for up to 6D (unroll for speed, expand if needed)
124+
REAL spline_val = 0;
125+
REAL spline_dx[NATURAL_BSPLINE_ND_MAX_DIMS] = {0};
126+
int stride[NATURAL_BSPLINE_ND_MAX_DIMS];
127+
int offset[NATURAL_BSPLINE_ND_MAX_DIMS];
128+
129+
for (d = 0; d < D; ++d)
130+
{
131+
stride[d] = n_stride[d];
132+
offset[d] = span[d] - k;
133+
}
134+
135+
// Only support up to 6D for hardcoded loops (can extend with templates if needed)
136+
if (D == 1)
137+
{
138+
for (int i0 = 0; i0 < 4; ++i0)
139+
{
140+
int idx0 = (offset[0] + i0) * stride[0];
141+
REAL c = coeff[idx0];
142+
REAL w = B[0][i0];
143+
spline_val += w * c;
144+
spline_dx[0] += dB[0][i0] * c;
145+
}
146+
}
147+
else if (D == 2)
148+
{
149+
for (int i0 = 0; i0 < 4; ++i0)
150+
for (int i1 = 0; i1 < 4; ++i1)
151+
{
152+
int idx = (offset[0] + i0) * stride[0] + (offset[1] + i1) * stride[1];
153+
REAL c = coeff[idx];
154+
REAL w = B[0][i0] * B[1][i1];
155+
spline_val += w * c;
156+
spline_dx[0] += dB[0][i0] * B[1][i1] * c;
157+
spline_dx[1] += B[0][i0] * dB[1][i1] * c;
158+
}
159+
}
160+
else if (D == 3)
161+
{
162+
for (int i0 = 0; i0 < 4; ++i0)
163+
for (int i1 = 0; i1 < 4; ++i1)
164+
for (int i2 = 0; i2 < 4; ++i2)
165+
{
166+
int idx = (offset[0] + i0) * stride[0] +
167+
(offset[1] + i1) * stride[1] +
168+
(offset[2] + i2) * stride[2];
169+
REAL c = coeff[idx];
170+
REAL w = B[0][i0] * B[1][i1] * B[2][i2];
171+
spline_val += w * c;
172+
spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * c;
173+
spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * c;
174+
spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * c;
175+
}
176+
}
177+
else if (D == 4)
178+
{
179+
for (int i0 = 0; i0 < 4; ++i0)
180+
for (int i1 = 0; i1 < 4; ++i1)
181+
for (int i2 = 0; i2 < 4; ++i2)
182+
for (int i3 = 0; i3 < 4; ++i3)
183+
{
184+
int idx = (offset[0] + i0) * stride[0] +
185+
(offset[1] + i1) * stride[1] +
186+
(offset[2] + i2) * stride[2] +
187+
(offset[3] + i3) * stride[3];
188+
REAL c = coeff[idx];
189+
REAL w = B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3];
190+
spline_val += w * c;
191+
spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * c;
192+
spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * B[3][i3] * c;
193+
spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * B[3][i3] * c;
194+
spline_dx[3] += B[0][i0] * B[1][i1] * B[2][i2] * dB[3][i3] * c;
195+
}
196+
}
197+
else if (D == 5)
198+
{
199+
for (int i0 = 0; i0 < 4; ++i0)
200+
for (int i1 = 0; i1 < 4; ++i1)
201+
for (int i2 = 0; i2 < 4; ++i2)
202+
for (int i3 = 0; i3 < 4; ++i3)
203+
for (int i4 = 0; i4 < 4; ++i4)
204+
{
205+
int idx = (offset[0] + i0) * stride[0] +
206+
(offset[1] + i1) * stride[1] +
207+
(offset[2] + i2) * stride[2] +
208+
(offset[3] + i3) * stride[3] +
209+
(offset[4] + i4) * stride[4];
210+
REAL c = coeff[idx];
211+
REAL w = B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4];
212+
spline_val += w * c;
213+
spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * c;
214+
spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * c;
215+
spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * B[3][i3] * B[4][i4] * c;
216+
spline_dx[3] += B[0][i0] * B[1][i1] * B[2][i2] * dB[3][i3] * B[4][i4] * c;
217+
spline_dx[4] += B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * dB[4][i4] * c;
218+
}
219+
}
220+
else if (D == 6)
221+
{
222+
for (int i0 = 0; i0 < 4; ++i0)
223+
for (int i1 = 0; i1 < 4; ++i1)
224+
for (int i2 = 0; i2 < 4; ++i2)
225+
for (int i3 = 0; i3 < 4; ++i3)
226+
for (int i4 = 0; i4 < 4; ++i4)
227+
for (int i5 = 0; i5 < 4; ++i5)
228+
{
229+
int idx = (offset[0] + i0) * stride[0] +
230+
(offset[1] + i1) * stride[1] +
231+
(offset[2] + i2) * stride[2] +
232+
(offset[3] + i3) * stride[3] +
233+
(offset[4] + i4) * stride[4] +
234+
(offset[5] + i5) * stride[5];
235+
REAL c = coeff[idx];
236+
REAL w = B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * B[5][i5];
237+
spline_val += w * c;
238+
spline_dx[0] += dB[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * B[5][i5] * c;
239+
spline_dx[1] += B[0][i0] * dB[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * B[5][i5] * c;
240+
spline_dx[2] += B[0][i0] * B[1][i1] * dB[2][i2] * B[3][i3] * B[4][i4] * B[5][i5] * c;
241+
spline_dx[3] += B[0][i0] * B[1][i1] * B[2][i2] * dB[3][i3] * B[4][i4] * B[5][i5] * c;
242+
spline_dx[4] += B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * dB[4][i4] * B[5][i5] * c;
243+
spline_dx[5] += B[0][i0] * B[1][i1] * B[2][i2] * B[3][i3] * B[4][i4] * dB[5][i5] * c;
244+
}
245+
}
246+
247+
// --- Output model value ---
248+
value[point_index] = amp * spline_val + offset;
249+
250+
// --- Output derivatives ---
251+
REAL * der = derivative + point_index;
252+
der[0 * n_points] = spline_val; // d/d(amp)
253+
for (d = 0; d < D; ++d)
254+
der[(1 + d) * n_points] = -amp * spline_dx[d]; // d/d(center_d)
255+
der[(1 + D) * n_points] = 1; // d/d(offset)
256+
}
257+
258+
#endif

Gpufit/models/spline_1d.cuh

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,47 +75,50 @@ __device__ void calculate_spline1d(
7575
// read user_info
7676
REAL const * user_info_REAL = (REAL *)user_info;
7777

78-
int const n_intervals = static_cast<int>(*user_info_REAL);
79-
std::size_t const n_coefficients_per_interval = 4;
78+
std::size_t const n_points_x = static_cast<std::size_t>(*(user_info_REAL + 0));
79+
int const n_intervals_x = static_cast<int>(*(user_info_REAL + 1));
8080

81-
REAL const * coefficients = user_info_REAL + 1;
81+
std::size_t const n_coefficients_per_interval = 4;
82+
REAL const * coefficients = user_info_REAL + 2;
8283

8384
// parameters
8485
REAL const * p = parameters;
8586

8687
// estimate index i of the current spline interval
87-
REAL const x = static_cast<REAL>(point_index);
88-
REAL const position = x - p[1];
89-
int i = static_cast<int>(floor(position)); // can be negative
88+
REAL const position_x = point_index - p[1];
89+
int i = static_cast<int>(floor(position_x));
9090

9191
// adjust i to its bounds
9292
i = i >= 0 ? i : 0;
93-
i = i < n_intervals ? i : n_intervals - 1;
93+
i = i < n_intervals_x ? i : n_intervals_x - 1;
9494

9595
// get coefficients of the current interval
9696
REAL const * current_coefficients = coefficients + i * n_coefficients_per_interval;
9797

98-
// calculate position relative to the current spline interval
99-
REAL const x_diff = position - static_cast<REAL>(i);
98+
// estimate position relative to the current spline interval
99+
REAL const x_diff = position_x - i;
100100

101101
// intermediate values
102102
REAL temp_value = 0;
103103
REAL temp_derivative_1 = 0;
104104

105-
REAL power_factor = 1;
106-
for (std::size_t order = 0; order < n_coefficients_per_interval; order++)
105+
REAL power_factor_i = 1;
106+
for (int order_i = 0; order_i < 4; order_i++)
107107
{
108+
108109
// intermediate function value without amplitude and offset
109-
temp_value += current_coefficients[order] * power_factor;
110+
temp_value += current_coefficients[order_i] * power_factor_i;
110111

111-
// intermediate derivative value with respect to paramater 1 (center position)
112-
if (order < n_coefficients_per_interval - 1)
112+
// intermediate derivative value with respect to paramater 1 (center position x)
113+
if (order_i < 3)
114+
{
113115
temp_derivative_1
114-
+= (REAL(order) + 1)
115-
* current_coefficients[order + 1]
116-
* power_factor;
116+
+= (REAL(order_i) + 1)
117+
* current_coefficients[(order_i + 1)]
118+
* power_factor_i;
119+
}
117120

118-
power_factor *= x_diff;
121+
power_factor_i *= x_diff;
119122
}
120123

121124
// value

0 commit comments

Comments
 (0)