1+ """Module for Spline model"""
2+
3+ import torch
4+ import torch .nn as nn
5+ from ..utils import check_consistency
6+
7+ class Spline (torch .nn .Module ):
8+
9+ def __init__ (self , order = 4 , knots = None , control_points = None ) -> None :
10+ """
11+ Spline model.
12+
13+ :param int order: the order of the spline.
14+ :param torch.Tensor knots: the knot vector.
15+ :param torch.Tensor control_points: the control points.
16+ """
17+ super ().__init__ ()
18+
19+ check_consistency (order , int )
20+
21+ if order < 0 :
22+ raise ValueError ("Spline order cannot be negative." )
23+ if knots is None and control_points is None :
24+ raise ValueError ("Knots and control points cannot be both None." )
25+
26+ self .order = order
27+ self .k = order - 1
28+
29+ if knots is not None and control_points is not None :
30+ self .knots = knots
31+ self .control_points = control_points
32+
33+ elif knots is not None :
34+ print ('Warning: control points will be initialized automatically.' )
35+ print (' experimental feature' )
36+
37+ self .knots = knots
38+ n = len (knots ) - order
39+ self .control_points = torch .nn .Parameter (
40+ torch .zeros (n ), requires_grad = True )
41+
42+ elif control_points is not None :
43+ print ('Warning: knots will be initialized automatically.' )
44+ print (' experimental feature' )
45+
46+ self .control_points = control_points
47+
48+ n = len (self .control_points )- 1
49+ self .knots = {
50+ 'type' : 'auto' ,
51+ 'min' : 0 ,
52+ 'max' : 1 ,
53+ 'n' : n + 2 + self .order }
54+
55+ else :
56+ raise ValueError (
57+ "Knots and control points cannot be both None."
58+ )
59+
60+
61+ if self .knots .ndim != 1 :
62+ raise ValueError ("Knot vector must be one-dimensional." )
63+
64+ def basis (self , x , k , i , t ):
65+ '''
66+ Recursive function to compute the basis functions of the spline.
67+
68+ :param torch.Tensor x: points to be evaluated.
69+ :param int k: spline degree
70+ :param int i: the index of the interval
71+ :param torch.Tensor t: vector of knots
72+ :return: the basis functions evaluated at x
73+ :rtype: torch.Tensor
74+ '''
75+
76+ if k == 0 :
77+ a = torch .where (torch .logical_and (t [i ] <= x , x < t [i + 1 ]), 1.0 , 0.0 )
78+ if i == len (t ) - self .order - 1 :
79+ a = torch .where (x == t [- 1 ], 1.0 , a )
80+ a .requires_grad_ (True )
81+ return a
82+
83+
84+ if t [i + k ] == t [i ]:
85+ c1 = torch .tensor ([0.0 ]* len (x ), requires_grad = True )
86+ else :
87+ c1 = (x - t [i ])/ (t [i + k ] - t [i ]) * self .basis (x , k - 1 , i , t )
88+
89+ if t [i + k + 1 ] == t [i + 1 ]:
90+ c2 = torch .tensor ([0.0 ]* len (x ), requires_grad = True )
91+ else :
92+ c2 = (t [i + k + 1 ] - x )/ (t [i + k + 1 ] - t [i + 1 ]) * self .basis (x , k - 1 , i + 1 , t )
93+
94+ return c1 + c2
95+
96+
97+ @property
98+ def control_points (self ):
99+ return self ._control_points
100+
101+ @control_points .setter
102+ def control_points (self , value ):
103+ if isinstance (value , dict ):
104+ if 'n' not in value :
105+ raise ValueError ('Invalid value for control_points' )
106+ n = value ['n' ]
107+ dim = value .get ('dim' , 1 )
108+ value = torch .zeros (n , dim )
109+
110+ if not isinstance (value , torch .Tensor ):
111+ raise ValueError ('Invalid value for control_points' )
112+ self ._control_points = torch .nn .Parameter (value , requires_grad = True )
113+
114+ @property
115+ def knots (self ):
116+ return self ._knots
117+
118+ @knots .setter
119+ def knots (self , value ):
120+ if isinstance (value , dict ):
121+
122+ type_ = value .get ('type' , 'auto' )
123+ min_ = value .get ('min' , 0 )
124+ max_ = value .get ('max' , 1 )
125+ n = value .get ('n' , 10 )
126+
127+ if type_ == 'uniform' :
128+ value = torch .linspace (min_ , max_ , n + self .k + 1 )
129+ elif type_ == 'auto' :
130+ initial_knots = torch .ones (self .order + 1 )* min_
131+ final_knots = torch .ones (self .order + 1 )* max_
132+
133+ if n < self .order + 1 :
134+ value = torch .concatenate ((initial_knots , final_knots ))
135+ elif n - 2 * self .order + 1 == 1 :
136+ value = torch .Tensor ([(max_ + min_ )/ 2 ])
137+ else :
138+ value = torch .linspace (min_ , max_ , n - 2 * self .order - 1 )
139+
140+ value = torch .concatenate (
141+ (
142+ initial_knots , value , final_knots
143+ )
144+ )
145+
146+ if not isinstance (value , torch .Tensor ):
147+ raise ValueError ('Invalid value for knots' )
148+
149+ self ._knots = value
150+
151+ def forward (self , x_ ):
152+ """
153+ Forward pass of the spline model.
154+
155+ :param torch.Tensor x_: points to be evaluated.
156+ :return: the spline evaluated at x_
157+ :rtype: torch.Tensor
158+ """
159+ t = self .knots
160+ k = self .k
161+ c = self .control_points
162+
163+ basis = map (lambda i : self .basis (x_ , k , i , t )[:, None ], range (len (c )))
164+ y = (torch .cat (list (basis ), dim = 1 ) * c ).sum (axis = 1 )
165+
166+ return y
0 commit comments