Skip to content

Commit 071a03a

Browse files
committed
updated with latest N-D natural B-Spline fit model and example code
1 parent b9b39e2 commit 071a03a

File tree

7 files changed

+1380
-150
lines changed

7 files changed

+1380
-150
lines changed

Gpufit/constants.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@ enum ModelID {
2020
SPLINE_4D = 13,
2121
SPLINE_4D_MULTICHANNEL = 14,
2222
SPLINE_5D = 15,
23-
NATURAL_BSPLINE_1D = 16
23+
NATURAL_BSPLINE_1D = 16,
24+
NATURAL_BSPLINE_2D = 17,
25+
NATURAL_BSPLINE_3D = 18,
26+
NATURAL_BSPLINE_4D = 19,
27+
NATURAL_BSPLINE_5D = 20,
28+
NATURAL_BSPLINE_6D = 21
2429
};
2530

2631
// estimator ID

Gpufit/models/models.cuh

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "spline_4d_multichannel.cuh"
2020
#include "spline_5d.cuh"
2121
#include "natural_bspline_1d.cuh"
22+
#include "natural_bspline_nd.cuh"
2223

2324
__device__ void calculate_model(
2425
ModelID const model_id,
@@ -86,6 +87,21 @@ __device__ void calculate_model(
8687
case NATURAL_BSPLINE_1D:
8788
calculate_natural_bspline1d(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
8889
break;
90+
case NATURAL_BSPLINE_2D:
91+
calculate_natural_bspline_nd(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
92+
break;
93+
case NATURAL_BSPLINE_3D:
94+
calculate_natural_bspline_nd(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
95+
break;
96+
case NATURAL_BSPLINE_4D:
97+
calculate_natural_bspline_nd(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
98+
break;
99+
case NATURAL_BSPLINE_5D:
100+
calculate_natural_bspline_nd(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
101+
break;
102+
case NATURAL_BSPLINE_6D:
103+
calculate_natural_bspline_nd(parameters, n_fits, n_points, value, derivative, point_index, fit_index, chunk_index, user_info, user_info_size);
104+
break;
89105
default:
90106
assert(0); // unknown model ID
91107
}
@@ -112,6 +128,11 @@ void configure_model(ModelID const model_id, int & n_parameters, int & n_dimensi
112128
case SPLINE_4D_MULTICHANNEL: n_parameters = 6; n_dimensions = 5; break;
113129
case SPLINE_5D: n_parameters = 7; n_dimensions = 5; break;
114130
case NATURAL_BSPLINE_1D: n_parameters = 3; n_dimensions = 1; break;
131+
case NATURAL_BSPLINE_2D: n_parameters = 4; n_dimensions = 2; break;
132+
case NATURAL_BSPLINE_3D: n_parameters = 5; n_dimensions = 3; break;
133+
case NATURAL_BSPLINE_4D: n_parameters = 6; n_dimensions = 4; break;
134+
case NATURAL_BSPLINE_5D: n_parameters = 7; n_dimensions = 5; break;
135+
case NATURAL_BSPLINE_6D: n_parameters = 8; n_dimensions = 6; break;
115136
default: throw std::runtime_error("unknown model ID");
116137
}
117138
}

Gpufit/models/natural_bspline_1d.cuh

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
*
1919
* parameters: An input vector of concatenated sets of model parameters.
2020
* p[0]: amplitude
21-
* p[1]: center coordinate
21+
* p[1]: x-shift
2222
* p[2]: offset
2323
*
2424
* n_fits: The number of fits.
@@ -60,30 +60,40 @@ __device__ void calculate_natural_bspline1d(
6060
REAL const * ui = (REAL*)user_info;
6161

6262
int const num_coeff = static_cast<int>(ui[0]);
63+
int const N_tpl = num_coeff - 2; // template length
6364
int const num_knots = num_coeff + 4; // cubic
6465
REAL const * knots = ui + 1;
6566
REAL const * coeff = ui + 1 + num_knots;
6667

67-
// Model parameters: [amp, center, offset]
68+
// Model parameters: [amp, shift, offset]
6869
REAL const * p = parameters;
6970
REAL amp = p[0];
70-
REAL center = p[1];
71+
REAL shift = p[1];
7172
REAL offset = p[2];
7273

7374
// Data point coordinate
7475
REAL x = static_cast<REAL>(point_index);
75-
REAL xq = x - center;
76+
REAL xq = x - p[1];
77+
78+
bool clamped = false;
79+
if (xq < (REAL)0) { xq = (REAL)0; clamped = true; }
80+
if (xq > (REAL)(N_tpl - 1)) { xq = (REAL)(N_tpl - 1); clamped = true; }
7681

7782
// Find knot span as in host code
7883
const int k = 3;
79-
int N = num_coeff - 2;
8084
int span;
81-
if (xq <= 0.0)
85+
if (xq <= (REAL)0)
86+
{
8287
span = k;
83-
else if (xq >= REAL(N - 1))
88+
}
89+
else if (xq >= (REAL)(N_tpl - 1))
90+
{
8491
span = num_coeff - 1;
92+
}
8593
else
86-
span = int(xq) + k;
94+
{
95+
span = (int)xq + k;
96+
}
8797

8898
// Evaluate basis and derivative (device function!)
8999
REAL basis[4], dbasis[4];
@@ -102,13 +112,15 @@ __device__ void calculate_natural_bspline1d(
102112
}
103113
}
104114

115+
if (clamped) spline_dx = (REAL)0;
116+
105117
// Write value
106118
value[point_index] = amp * spline_val + offset;
107119

108-
// Write derivatives [amp, center, offset]
120+
// Write derivatives [amp, shift, offset]
109121
REAL * der = derivative + point_index;
110122
der[0 * n_points] = spline_val; // d/d(amp)
111-
der[1 * n_points] = -amp * spline_dx; // d/d(center)
123+
der[1 * n_points] = -amp * spline_dx; // d/d(shift)
112124
der[2 * n_points] = 1; // d/d(offset)
113125
}
114126

0 commit comments

Comments
 (0)