@@ -70,6 +70,7 @@ void MassSpringEnergy<T, dim>::update_k(const std::vector<T> &k)
7070 pimpl_->device_k .copy_from (k);
7171}
7272
73+ // ANCHOR: val
7374template <typename T, int dim>
7475T MassSpringEnergy<T, dim>::val()
7576{
@@ -81,19 +82,21 @@ T MassSpringEnergy<T, dim>::val()
8182 DeviceBuffer<T> device_val (N);
8283 ParallelFor (256 ).apply (N, [device_val = device_val.viewer (), device_x = device_x.cviewer (), device_e = device_e.cviewer (), device_l2 = device_l2.cviewer (), device_k = device_k.cviewer ()] __device__ (int i) mutable
8384 {
84- int idx1= device_e (2 * i); // First node index
85- int idx2 = device_e (2 * i + 1 ); // Second node index
86- T diff = 0 ;
87- for (int d = 0 ; d < dim;d++){
88- T diffi = device_x (dim * idx1 + d) - device_x (dim * idx2 + d);
89- diff += diffi * diffi;
90- }
91- device_val (i) = 0.5 * device_l2 (i) * device_k (i) * (diff / device_l2 (i) - 1 ) * (diff / device_l2 (i) - 1 ); })
85+ int idx1= device_e (2 * i); // First node index
86+ int idx2 = device_e (2 * i + 1 ); // Second node index
87+ T diff = 0 ;
88+ for (int d = 0 ; d < dim;d++){
89+ T diffi = device_x (dim * idx1 + d) - device_x (dim * idx2 + d);
90+ diff += diffi * diffi;
91+ }
92+ device_val (i) = 0.5 * device_l2 (i) * device_k (i) * (diff / device_l2 (i) - 1 ) * (diff / device_l2 (i) - 1 ); })
9293 .wait ();
9394
9495 return devicesum (device_val);
9596} // Calculate the energy
97+ // ANCHOR_END: val
9698
99+ // ANCHOR: grad
97100template <typename T, int dim>
98101const DeviceBuffer<T> &MassSpringEnergy<T, dim>::grad()
99102{
@@ -106,25 +109,26 @@ const DeviceBuffer<T> &MassSpringEnergy<T, dim>::grad()
106109 device_grad.fill (0 );
107110 ParallelFor (256 ).apply (N, [device_x = device_x.cviewer (), device_e = device_e.cviewer (), device_l2 = device_l2.cviewer (), device_k = device_k.cviewer (), device_grad = device_grad.viewer ()] __device__ (int i) mutable
108111 {
109- int idx1= device_e (2 * i); // First node index
110- int idx2 = device_e (2 * i + 1 ); // Second node index
111- T diff = 0 ;
112- T diffi[dim];
113- for (int d = 0 ; d < dim;d++){
114- diffi[d] = device_x (dim * idx1 + d) - device_x (dim * idx2 + d);
115- diff += diffi[d] * diffi[d];
116- }
117- T factor = 2 * device_k (i) * (diff / device_l2 (i) -1 );
118- for (int d=0 ;d<dim;d++){
119- atomicAdd (&device_grad (dim * idx1 + d), factor * diffi[d]);
120- atomicAdd (&device_grad (dim * idx2 + d), -factor * diffi[d]);
121-
122- } })
112+ int idx1= device_e (2 * i); // First node index
113+ int idx2 = device_e (2 * i + 1 ); // Second node index
114+ T diff = 0 ;
115+ T diffi[dim];
116+ for (int d = 0 ; d < dim;d++){
117+ diffi[d] = device_x (dim * idx1 + d) - device_x (dim * idx2 + d);
118+ diff += diffi[d] * diffi[d];
119+ }
120+ T factor = 2 * device_k (i) * (diff / device_l2 (i) -1 );
121+ for (int d=0 ;d<dim;d++){
122+ atomicAdd (&device_grad (dim * idx1 + d), factor * diffi[d]);
123+ atomicAdd (&device_grad (dim * idx2 + d), -factor * diffi[d]);
124+ } })
123125 .wait ();
124126 // display_vec(device_grad);
125127 return device_grad;
126128}
129+ // ANCHOR_END: grad
127130
131+ // ANCHOR: hess
128132template <typename T, int dim>
129133const DeviceTripletMatrix<T, 1 > &MassSpringEnergy<T, dim>::hess()
130134{
@@ -170,8 +174,8 @@ const DeviceTripletMatrix<T, 1> &MassSpringEnergy<T, dim>::hess()
170174 } })
171175 .wait ();
172176 return device_hess;
173-
174177} // Calculate the Hessian of the energy
178+ // ANCHOR_END: hess
175179
176180template class MassSpringEnergy <float , 2 >;
177181template class MassSpringEnergy <float , 3 >;
0 commit comments