Skip to content

Commit 154965f

Browse files
committed
new fit models including natural b-splines
1 parent c2f520d commit 154965f

File tree

6 files changed

+509
-6
lines changed

6 files changed

+509
-6
lines changed

Gpufit/constants.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ enum ModelID {
1818
SPLINE_3D_MULTICHANNEL = 11,
1919
SPLINE_3D_PHASE_MULTICHANNEL = 12,
2020
SPLINE_4D = 13,
21-
SPLINE_5D = 14
21+
SPLINE_4D_MULTICHANNEL = 14,
22+
SPLINE_5D = 15,
23+
NATURAL_BSPLINE_1D = 16
2224
};
2325

2426
// estimator ID

Gpufit/models/models.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
#include "spline_3d_multichannel.cuh"
1717
#include "spline_3d_phase_multichannel.cuh"
1818
#include "spline_4d.cuh"
19+
#include "spline_4d_multichannel.cuh"
1920
#include "spline_5d.cuh"
21+
#include "natural_bspline_1d.cuh"
2022

2123
__device__ void calculate_model(
2224
ModelID const model_id,
@@ -75,9 +77,15 @@ __device__ void calculate_model(
7577
case SPLINE_4D:
7678
calculate_spline4d(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
7779
break;
80+
case SPLINE_4D_MULTICHANNEL:
81+
calculate_spline4d_multichannel(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
82+
break;
7883
case SPLINE_5D:
7984
calculate_spline5d(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
8085
break;
86+
case NATURAL_BSPLINE_1D:
87+
calculate_natural_bspline1d(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
88+
break;
8189
default:
8290
assert(0); // unknown model ID
8391
}
@@ -101,7 +109,9 @@ void configure_model(ModelID const model_id, int & n_parameters, int & n_dimensi
101109
case SPLINE_3D_MULTICHANNEL: n_parameters = 5; n_dimensions = 4; break;
102110
case SPLINE_3D_PHASE_MULTICHANNEL: n_parameters = 6; n_dimensions = 4; break;
103111
case SPLINE_4D: n_parameters = 6; n_dimensions = 4; break;
112+
case SPLINE_4D_MULTICHANNEL: n_parameters = 6; n_dimensions = 5; break;
104113
case SPLINE_5D: n_parameters = 7; n_dimensions = 5; break;
114+
case NATURAL_BSPLINE_1D: n_parameters = 3; n_dimensions = 1; break;
105115
default: throw std::runtime_error("unknown model ID");
106116
}
107117
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#ifndef GPUFIT_NATURAL_BSPLINE1D_CUH_INCLUDED
2+
#define GPUFIT_NATURAL_BSPLINE1D_CUH_INCLUDED
3+
4+
#include "natural_bspline_fast_cubic_basis_device.cuh"
5+
6+
/*
 * Description of the calculate_natural_bspline1d function
 * =======================================================
 *
 * This function calculates a value of a one-dimensional natural cubic B-spline
 * model function and its partial derivatives with respect to the model parameters:
 *
 *     f(x) = p[0] * S(x - p[1]) + p[2]
 *
 * where S is the spline described by the knots and coefficients in user_info.
 *
 * The X coordinate of the first data value is assumed to be 0.0. For
 * a fit size of N data points, the X coordinates of the data are
 * simply the corresponding array index values of the data array, starting from zero.
 *
 * Parameters:
 *
 * parameters: An input vector of model parameters.
 *             p[0]: amplitude
 *             p[1]: center coordinate
 *             p[2]: offset
 *
 * n_fits: The number of fits. (not used)
 *
 * n_points: The number of data points per fit. Used as the stride between the
 *           per-parameter slices of the derivative array.
 *
 * value: Output. The model value is written to value[point_index].
 *
 * derivative: Output. The partial derivative with respect to parameter j is
 *             written to derivative[j * n_points + point_index].
 *
 * point_index: The data point index; also used as the X coordinate.
 *
 * fit_index: The fit index. (not used)
 *
 * chunk_index: The chunk index. (not used)
 *
 * user_info: Spline description, read as an array of REAL:
 *            user_info[0]: number of control points, stored as REAL
 *                          and truncated to int
 *            user_info[1 .. num_control_points + 4]: knot vector
 *                          (num_control_points + 4 values)
 *            user_info[num_control_points + 5 ...]: spline coefficients
 *                          (num_control_points values)
 *
 * user_info_size: The size of user_info. (not used; the original comment gives
 *                 the unit as bytes - TODO confirm against the caller)
 *
 * Calling the calculate_natural_bspline1d function
 * ================================================
 *
 * This __device__ function can only be called from a __global__ function or
 * another __device__ function.
 *
 */

__device__ void calculate_natural_bspline1d(
    REAL const * parameters,
    int const n_fits,
    int const n_points,
    REAL * value,
    REAL * derivative,
    int const point_index,
    int const fit_index,
    int const chunk_index,
    char * user_info,
    std::size_t const user_info_size)
{
    int const degree = 3; // cubic spline

    // View the raw user_info buffer as REAL values and split it into its sections.
    REAL const * const info = reinterpret_cast<REAL const *>(user_info);
    int const num_coeff = static_cast<int>(info[0]);
    int const num_knots = num_coeff + degree + 1;
    REAL const * const knots = info + 1;
    REAL const * const coeff = knots + num_knots;

    // Model parameters: [amplitude, center, offset]
    REAL const amplitude = parameters[0];
    REAL const center = parameters[1];
    REAL const offset = parameters[2];

    // Query coordinate in the spline's own frame.
    REAL const xq = static_cast<REAL>(point_index) - center;

    // Select the knot span (mirrors the host code). Queries left of 0 use the
    // first span and queries at or beyond (num_coeff - 2) - 1 use the last span,
    // so the boundary polynomial pieces are extrapolated outside that range.
    // NOTE(review): the span lookup truncates xq to an integer, which presumably
    // relies on unit-spaced knots starting at 0 - confirm against the host code.
    int const n_samples = num_coeff - 2;
    int span;
    if (xq <= REAL(0))
    {
        span = degree;
    }
    else if (xq >= REAL(n_samples - 1))
    {
        span = num_coeff - 1;
    }
    else
    {
        span = static_cast<int>(xq) + degree;
    }

    // Zero-initialise so that no indeterminate values are accumulated should a
    // helper leave its output untouched.
    REAL basis[4] = { 0, 0, 0, 0 };
    REAL dbasis[4] = { 0, 0, 0, 0 };
    nat_bspl_fast_cubic_basis_device(xq, span, knots, num_coeff, basis);
    nat_bspl_fast_cubic_basis_derivative_device(xq, span, knots, num_coeff, dbasis);

    // Accumulate the spline value and its derivative with respect to xq over the
    // (up to) four control points that are active on this span.
    REAL spline_val = 0;
    REAL spline_dx = 0;
    for (int i = 0; i < 4; ++i)
    {
        int const idx = span - degree + i;
        if (idx >= 0 && idx < num_coeff)
        {
            spline_val += basis[i] * coeff[idx];
            spline_dx += dbasis[i] * coeff[idx];
        }
    }

    // Model value
    value[point_index] = amplitude * spline_val + offset;

    // Partial derivatives, one slice of n_points entries per parameter.
    REAL * const current_derivative = derivative + point_index;
    current_derivative[0 * n_points] = spline_val;              // d/d(amplitude)
    current_derivative[1 * n_points] = -amplitude * spline_dx;  // d/d(center): d(xq)/d(center) = -1
    current_derivative[2 * n_points] = 1;                       // d/d(offset)
}
114+
115+
#endif
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
#ifndef GPUFIT_NAT_BSPL_BASIS_CUH_INCLUDED
2+
#define GPUFIT_NAT_BSPL_BASIS_CUH_INCLUDED
3+
4+
// Replace REAL with float or double as needed for your GPU build.
5+
6+
__device__ void nat_bspl_fast_cubic_basis_device(
7+
float x, // query coordinate
8+
int span, // knot interval
9+
const float* knots, // pointer to knot vector (device)
10+
int num_control_points, // # control points for the axis
11+
float* basis // output [4]
12+
)
13+
{
14+
const int k = 3; // Cubic
15+
float t = x - knots[span];
16+
17+
if (span == k) {
18+
// First support interval
19+
// Paste: evaluate_fast_cubic_basis_first_span
20+
float t2 = t * t;
21+
float t3 = t2 * t;
22+
basis[0] = 1.0f - 3.0f * t + 3.0f * t2 - t3;
23+
basis[1] = 3.0f * t - 4.5f * t2 + 1.75f * t3;
24+
basis[2] = 1.5f * t2 - (11.0f / 12.0f) * t3;
25+
basis[3] = t3 / 6.0f;
26+
}
27+
else if (span == k + 1) {
28+
// Second support interval
29+
float t2 = t * t;
30+
float t3 = t2 * t;
31+
basis[0] = 0.25f - 0.75f * t + 0.75f * t2 - 0.25f * t3;
32+
basis[1] = (7.0f / 12.0f) + 0.25f * t - 1.25f * t2 + (7.0f / 12.0f) * t3;
33+
basis[2] = (1.0f / 6.0f) + 0.5f * t + 0.5f * t2 - 0.5f * t3;
34+
basis[3] = t3 / 6.0f;
35+
}
36+
else if (span >= k + 2 && span <= num_control_points - 3) {
37+
// Interior
38+
float t2 = t * t;
39+
float t3 = t2 * t;
40+
float omt = 1.0f - t;
41+
basis[0] = (1.0f / 6.0f) * omt * omt * omt;
42+
basis[1] = (1.0f / 6.0f) * (3.0f * t3 - 6.0f * t2 + 4.0f);
43+
basis[2] = (1.0f / 6.0f) * (-3.0f * t3 + 3.0f * t2 + 3.0f * t + 1.0f);
44+
basis[3] = (1.0f / 6.0f) * t3;
45+
}
46+
else if (span == num_control_points - 2) {
47+
// Second last span
48+
float omt = 1.0f - t;
49+
float omt2 = omt * omt;
50+
float omt3 = omt2 * omt;
51+
basis[0] = omt3 / 6.0f;
52+
basis[1] = (1.0f / 6.0f) + 0.5f * omt + 0.5f * omt2 - 0.5f * omt3;
53+
basis[2] = (7.0f / 12.0f) + 0.25f * omt - 1.25f * omt2 + (7.0f / 12.0f) * omt3;
54+
basis[3] = 0.25f - 0.75f * omt + 0.75f * omt2 - 0.25f * omt3;
55+
}
56+
else if (span == num_control_points - 1) {
57+
// Last span
58+
float omt = 1.0f - t;
59+
float omt2 = omt * omt;
60+
float omt3 = omt2 * omt;
61+
basis[0] = omt3 / 6.0f;
62+
basis[1] = 1.5f * omt2 - (11.0f / 12.0f) * omt3;
63+
basis[2] = 3.0f * omt - 4.5f * omt2 + 1.75f * omt3;
64+
basis[3] = 1.0f - 3.0f * omt + 3.0f * omt2 - omt3;
65+
}
66+
}
67+
68+
69+
// Evaluates the first derivatives, with respect to x, of the four cubic B-spline
// basis functions that are active on one knot span.
//
// The scalar type T is deduced from the arguments, so the same code serves both
// single precision (T = float) and double precision (T = double) builds.
//
// Parameters:
//   x                   query coordinate
//   span                index of the knot span; knots[span] is subtracted from x
//                       to obtain the local coordinate t
//   knots               pointer to the knot vector (device memory)
//   num_control_points  number of control points of the spline
//   dbasis              output, 4 values
//
// Supported spans are 3 .. num_control_points - 1. For any other span the four
// outputs are set to zero and the knot vector is not read.
//
// The polynomial piece is chosen in this order: first span (3), second span (4),
// interior (5 .. num_control_points - 3), second last (num_control_points - 2),
// last (num_control_points - 1). For short splines where these cases overlap,
// the earlier one wins.
template <typename T>
__device__ void nat_bspl_fast_cubic_basis_derivative_device(
    T x,
    int span,
    const T* knots,
    int num_control_points,
    T* dbasis
)
{
    const int k = 3; // Cubic

    // Unsupported span: produce well-defined output and avoid reading
    // knots[span] with an index that may be out of range.
    if (span < k || span > num_control_points - 1) {
        dbasis[0] = T(0);
        dbasis[1] = T(0);
        dbasis[2] = T(0);
        dbasis[3] = T(0);
        return;
    }

    const T t = x - knots[span];

    if (span == k) {
        // First support interval
        const T t2 = t * t;
        dbasis[0] = -T(3) + T(6) * t - T(3) * t2;
        dbasis[1] = T(3) - T(9) * t + T(5.25) * t2;
        dbasis[2] = T(3) * t - T(2.75) * t2;
        dbasis[3] = T(0.5) * t2;
    }
    else if (span == k + 1) {
        // Second support interval
        const T t2 = t * t;
        dbasis[0] = -T(0.75) + T(1.5) * t - T(0.75) * t2;
        dbasis[1] = T(0.25) - T(2.5) * t + T(1.75) * t2;
        dbasis[2] = T(0.5) + t - T(1.5) * t2;
        dbasis[3] = T(0.5) * t2;
    }
    else if (span >= k + 2 && span <= num_control_points - 3) {
        // Interior
        const T t2 = t * t;
        const T omt = T(1) - t;
        dbasis[0] = -T(0.5) * omt * omt;
        dbasis[1] = T(1.5) * t2 - T(2) * t;
        dbasis[2] = -T(1.5) * t2 + t + T(0.5);
        dbasis[3] = T(0.5) * t2;
    }
    else if (span == num_control_points - 2) {
        // Second last span; expressed in (1 - t), hence the flipped signs
        // relative to the second span.
        const T omt = T(1) - t;
        const T omt2 = omt * omt;
        dbasis[0] = -T(0.5) * omt2;
        dbasis[1] = -T(0.5) - omt + T(1.5) * omt2;
        dbasis[2] = -T(0.25) + T(2.5) * omt - T(1.75) * omt2;
        dbasis[3] = T(0.75) - T(1.5) * omt + T(0.75) * omt2;
    }
    else {
        // Last span (span == num_control_points - 1); expressed in (1 - t),
        // hence the flipped signs relative to the first span.
        const T omt = T(1) - t;
        const T omt2 = omt * omt;
        dbasis[0] = -T(0.5) * omt2;
        dbasis[1] = -T(3) * omt + T(2.75) * omt2;
        dbasis[2] = -T(3) + T(9) * omt - T(5.25) * omt2;
        dbasis[3] = T(3) - T(6) * omt + T(3) * omt2;
    }
}
119+
120+
121+
#endif

Gpufit/models/spline_2d.cuh

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,13 @@
4545
* user_info[3]: the number of spline intervals in Y
4646
* user_info[4]: the value of coefficient (0,0) of interval (0,0)
4747
* user_info[5]: the value of coefficient (1,0) of interval (0,0)
48+
* user_info[6]: the value of coefficient (2,0) of interval (0,0)
49+
* user_info[7]: the value of coefficient (3,0) of interval (0,0)
50+
* user_info[8]: the value of coefficient (0,1) of interval (0,0)
4851
* .
4952
* .
5053
* .
51-
* user_info[8]: the value of coefficient (0,1) of intervall (0,0)
52-
* .
53-
* .
54-
* .
55-
* user_info[20]: the value of coefficient (0,0) of intervall (1,0)
54+
* user_info[20]: the value of coefficient (0,0) of interval (1,0)
5655
* .
5756
* .
5857
* .

0 commit comments

Comments
 (0)