33import torch
44import torch .nn as nn
55from ..utils import check_consistency
6-
6+
7+
78class Spline (torch .nn .Module ):
89
910 def __init__ (self , order = 4 , knots = None , control_points = None ) -> None :
@@ -31,38 +32,37 @@ def __init__(self, order=4, knots=None, control_points=None) -> None:
3132 self .control_points = control_points
3233
3334 elif knots is not None :
34- print (' Warning: control points will be initialized automatically.' )
35- print (' experimental feature' )
35+ print (" Warning: control points will be initialized automatically." )
36+ print (" experimental feature" )
3637
3738 self .knots = knots
3839 n = len (knots ) - order
3940 self .control_points = torch .nn .Parameter (
40- torch .zeros (n ), requires_grad = True )
41-
41+ torch .zeros (n ), requires_grad = True
42+ )
43+
4244 elif control_points is not None :
43- print (' Warning: knots will be initialized automatically.' )
44- print (' experimental feature' )
45-
45+ print (" Warning: knots will be initialized automatically." )
46+ print (" experimental feature" )
47+
4648 self .control_points = control_points
4749
48- n = len (self .control_points )- 1
50+ n = len (self .control_points ) - 1
4951 self .knots = {
50- 'type' : 'auto' ,
51- 'min' : 0 ,
52- 'max' : 1 ,
53- 'n' : n + 2 + self .order }
52+ "type" : "auto" ,
53+ "min" : 0 ,
54+ "max" : 1 ,
55+ "n" : n + 2 + self .order ,
56+ }
5457
5558 else :
56- raise ValueError (
57- "Knots and control points cannot be both None."
58- )
59-
59+ raise ValueError ("Knots and control points cannot be both None." )
6060
6161 if self .knots .ndim != 1 :
6262 raise ValueError ("Knot vector must be one-dimensional." )
6363
6464 def basis (self , x , k , i , t ):
65- '''
65+ """
6666 Recursive function to compute the basis functions of the spline.
6767
6868 :param torch.Tensor x: points to be evaluated.
@@ -71,28 +71,32 @@ def basis(self, x, k, i, t):
7171 :param torch.Tensor t: vector of knots
7272 :return: the basis functions evaluated at x
7373 :rtype: torch.Tensor
74- '''
75-
74+ """
75+
7676 if k == 0 :
77- a = torch .where (torch .logical_and (t [i ] <= x , x < t [i + 1 ]), 1.0 , 0.0 )
77+ a = torch .where (
78+ torch .logical_and (t [i ] <= x , x < t [i + 1 ]), 1.0 , 0.0
79+ )
7880 if i == len (t ) - self .order - 1 :
79- a = torch .where (x == t [- 1 ], 1.0 , a )
81+ a = torch .where (x == t [- 1 ], 1.0 , a )
8082 a .requires_grad_ (True )
8183 return a
8284
83-
84- if t [i + k ] == t [i ]:
85- c1 = torch .tensor ([0.0 ]* len (x ), requires_grad = True )
85+ if t [i + k ] == t [i ]:
86+ c1 = torch .tensor ([0.0 ] * len (x ), requires_grad = True )
8687 else :
87- c1 = (x - t [i ])/ (t [i + k ] - t [i ]) * self .basis (x , k - 1 , i , t )
88+ c1 = (x - t [i ]) / (t [i + k ] - t [i ]) * self .basis (x , k - 1 , i , t )
8889
89- if t [i + k + 1 ] == t [i + 1 ]:
90- c2 = torch .tensor ([0.0 ]* len (x ), requires_grad = True )
90+ if t [i + k + 1 ] == t [i + 1 ]:
91+ c2 = torch .tensor ([0.0 ] * len (x ), requires_grad = True )
9192 else :
92- c2 = (t [i + k + 1 ] - x )/ (t [i + k + 1 ] - t [i + 1 ]) * self .basis (x , k - 1 , i + 1 , t )
93+ c2 = (
94+ (t [i + k + 1 ] - x )
95+ / (t [i + k + 1 ] - t [i + 1 ])
96+ * self .basis (x , k - 1 , i + 1 , t )
97+ )
9398
9499 return c1 + c2
95-
96100
97101 @property
98102 def control_points (self ):
@@ -101,50 +105,46 @@ def control_points(self):
101105 @control_points .setter
102106 def control_points (self , value ):
103107 if isinstance (value , dict ):
104- if 'n' not in value :
105- raise ValueError (' Invalid value for control_points' )
106- n = value ['n' ]
107- dim = value .get (' dim' , 1 )
108+ if "n" not in value :
109+ raise ValueError (" Invalid value for control_points" )
110+ n = value ["n" ]
111+ dim = value .get (" dim" , 1 )
108112 value = torch .zeros (n , dim )
109113
110114 if not isinstance (value , torch .Tensor ):
111- raise ValueError (' Invalid value for control_points' )
115+ raise ValueError (" Invalid value for control_points" )
112116 self ._control_points = torch .nn .Parameter (value , requires_grad = True )
113117
114118 @property
115119 def knots (self ):
116120 return self ._knots
117-
121+
118122 @knots .setter
119123 def knots (self , value ):
120124 if isinstance (value , dict ):
121125
122- type_ = value .get (' type' , ' auto' )
123- min_ = value .get (' min' , 0 )
124- max_ = value .get (' max' , 1 )
125- n = value .get ('n' , 10 )
126+ type_ = value .get (" type" , " auto" )
127+ min_ = value .get (" min" , 0 )
128+ max_ = value .get (" max" , 1 )
129+ n = value .get ("n" , 10 )
126130
127- if type_ == ' uniform' :
131+ if type_ == " uniform" :
128132 value = torch .linspace (min_ , max_ , n + self .k + 1 )
129- elif type_ == ' auto' :
130- initial_knots = torch .ones (self .order + 1 ) * min_
131- final_knots = torch .ones (self .order + 1 ) * max_
133+ elif type_ == " auto" :
134+ initial_knots = torch .ones (self .order + 1 ) * min_
135+ final_knots = torch .ones (self .order + 1 ) * max_
132136
133137 if n < self .order + 1 :
134138 value = torch .concatenate ((initial_knots , final_knots ))
135- elif n - 2 * self .order + 1 == 1 :
136- value = torch .Tensor ([(max_ + min_ )/ 2 ])
139+ elif n - 2 * self .order + 1 == 1 :
140+ value = torch .Tensor ([(max_ + min_ ) / 2 ])
137141 else :
138- value = torch .linspace (min_ , max_ , n - 2 * self .order - 1 )
142+ value = torch .linspace (min_ , max_ , n - 2 * self .order - 1 )
139143
140- value = torch .concatenate (
141- (
142- initial_knots , value , final_knots
143- )
144- )
144+ value = torch .concatenate ((initial_knots , value , final_knots ))
145145
146146 if not isinstance (value , torch .Tensor ):
147- raise ValueError (' Invalid value for knots' )
147+ raise ValueError (" Invalid value for knots" )
148148
149149 self ._knots = value
150150
@@ -154,7 +154,7 @@ def forward(self, x_):
154154
155155 :param torch.Tensor x_: points to be evaluated.
156156 :return: the spline evaluated at x_
157- :rtype: torch.Tensor
157+ :rtype: torch.Tensor
158158 """
159159 t = self .knots
160160 k = self .k
@@ -163,4 +163,4 @@ def forward(self, x_):
163163 basis = map (lambda i : self .basis (x_ , k , i , t )[:, None ], range (len (c )))
164164 y = (torch .cat (list (basis ), dim = 1 ) * c ).sum (axis = 1 )
165165
166- return y
166+ return y
0 commit comments