@@ -55,7 +55,6 @@ def codegen(fpsi, energy_name: str):
55
55
#include "pbat/math/linalg/mini/Matrix.h"
56
56
57
57
#include <cmath>
58
- #include <tuple>
59
58
60
59
namespace pbat {{
61
60
namespace physics {{
@@ -82,17 +81,17 @@ def codegen(fpsi, energy_name: str):
82
81
hesspsicode = cg .codegen (hesspsi .transpose (
83
82
), lhs = sp .MatrixSymbol ("H" , vecF .shape [0 ], vecF .shape [0 ]), scalar_type = "ScalarType" )
84
83
evalgradpsi = cg .codegen ([psi , gradpsi ], lhs = [sp .Symbol (
85
- "psi" ), sp .MatrixSymbol ("G " , * gradpsi .shape )], scalar_type = "ScalarType" )
84
+ "psi" ), sp .MatrixSymbol ("gF " , * gradpsi .shape )], scalar_type = "ScalarType" )
86
85
evalgradhesspsi = cg .codegen ([psi , gradpsi , hesspsi ], lhs = [
87
86
sp .Symbol ("psi" ),
88
- sp .MatrixSymbol ("G " , * gradpsi .shape ),
87
+ sp .MatrixSymbol ("gF " , * gradpsi .shape ),
89
88
sp .MatrixSymbol (
90
- "H " , vecF .shape [0 ], vecF .shape [0 ])
89
+ "HF " , vecF .shape [0 ], vecF .shape [0 ])
91
90
], scalar_type = "ScalarType" )
92
91
gradhesspsi = cg .codegen ([gradpsi , hesspsi ], lhs = [
93
- sp .MatrixSymbol ("G " , * gradpsi .shape ),
92
+ sp .MatrixSymbol ("gF " , * gradpsi .shape ),
94
93
sp .MatrixSymbol (
95
- "H " , vecF .shape [0 ], vecF .shape [0 ])
94
+ "HF " , vecF .shape [0 ], vecF .shape [0 ])
96
95
], scalar_type = "ScalarType" )
97
96
impl = f"""
98
97
template <>
@@ -131,39 +130,45 @@ def codegen(fpsi, energy_name: str):
131
130
typename TMatrix::ScalarType mu,
132
131
typename TMatrix::ScalarType lambda) const;
133
132
134
- template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
135
- PBAT_HOST_DEVICE
136
- std::tuple<
137
- typename TMatrix::ScalarType,
138
- SVector<typename TMatrix::ScalarType, { vecF .shape [0 ]} >
133
+ template <
134
+ math::linalg::mini::CReadableVectorizedMatrix TMatrix,
135
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixGF
139
136
>
137
+ PBAT_HOST_DEVICE
138
+ typename TMatrix::ScalarType
140
139
evalWithGrad(
141
140
TMatrix const& F,
142
141
typename TMatrix::ScalarType mu,
143
- typename TMatrix::ScalarType lambda) const;
142
+ typename TMatrix::ScalarType lambda,
143
+ TMatrixGF& gF) const;
144
144
145
- template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
146
- PBAT_HOST_DEVICE
147
- std::tuple<
148
- typename TMatrix::ScalarType,
149
- SVector<typename TMatrix::ScalarType, { vecF .shape [0 ]} >,
150
- SMatrix<typename TMatrix::ScalarType, { vecF .shape [0 ]} ,{ vecF .shape [0 ]} >
145
+ template <
146
+ math::linalg::mini::CReadableVectorizedMatrix TMatrix,
147
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixGF,
148
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixHF
151
149
>
150
+ PBAT_HOST_DEVICE
151
+ typename TMatrix::ScalarType
152
152
evalWithGradAndHessian(
153
153
TMatrix const& F,
154
154
typename TMatrix::ScalarType mu,
155
- typename TMatrix::ScalarType lambda) const;
156
-
157
- template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
158
- PBAT_HOST_DEVICE
159
- std::tuple<
160
- SVector<typename TMatrix::ScalarType, { vecF .shape [0 ]} >,
161
- SMatrix<typename TMatrix::ScalarType, { vecF .shape [0 ]} ,{ vecF .shape [0 ]} >
155
+ typename TMatrix::ScalarType lambda,
156
+ TMatrixGF& gF,
157
+ TMatrixHF& HF) const;
158
+
159
+ template <
160
+ math::linalg::mini::CReadableVectorizedMatrix TMatrix,
161
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixGF,
162
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixHF
162
163
>
164
+ PBAT_HOST_DEVICE
165
+ void
163
166
gradAndHessian(
164
167
TMatrix const& F,
165
168
typename TMatrix::ScalarType mu,
166
- typename TMatrix::ScalarType lambda) const;
169
+ typename TMatrix::ScalarType lambda,
170
+ TMatrixGF& gF,
171
+ TMatrixHF& HF) const;
167
172
}};
168
173
169
174
template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
@@ -208,60 +213,75 @@ def codegen(fpsi, energy_name: str):
208
213
return H;
209
214
}}
210
215
211
- template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
212
- PBAT_HOST_DEVICE
213
- std::tuple<
214
- typename TMatrix::ScalarType,
215
- { energy_name } <{ d } >::SVector<typename TMatrix::ScalarType, { vecF .shape [0 ]} >
216
+ template <
217
+ math::linalg::mini::CReadableVectorizedMatrix TMatrix,
218
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixGF
216
219
>
220
+ PBAT_HOST_DEVICE
221
+ typename TMatrix::ScalarType
217
222
{ energy_name } <{ d } >::evalWithGrad(
218
223
[[maybe_unused]] TMatrix const& F,
219
224
[[maybe_unused]] typename TMatrix::ScalarType mu,
220
- [[maybe_unused]] typename TMatrix::ScalarType lambda) const
225
+ [[maybe_unused]] typename TMatrix::ScalarType lambda,
226
+ TMatrixGF& gF) const
221
227
{{
228
+ static_assert(
229
+ TMatrixGF::kRows == { vecF .shape [0 ]} and TMatrixGF::kCols == 1,
230
+ "Grad w.r.t. F must have dimensions { vecF .shape [0 ]} x1");
222
231
using ScalarType = typename TMatrix::ScalarType;
223
232
ScalarType psi;
224
- SVector<ScalarType, { vecF .shape [0 ]} > G;
225
233
{ cg .tabulate (evalgradpsi , spaces = 4 )}
226
- return {{ psi, G}} ;
234
+ return psi;
227
235
}}
228
236
229
- template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
230
- PBAT_HOST_DEVICE
231
- std::tuple<
232
- typename TMatrix::ScalarType,
233
- { energy_name } <{ d } >::SVector<typename TMatrix::ScalarType, { vecF .shape [0 ]} >,
234
- { energy_name } <{ d } >::SMatrix<typename TMatrix::ScalarType, { vecF .shape [0 ]} ,{ vecF .shape [0 ]} >
237
+ template <
238
+ math::linalg::mini::CReadableVectorizedMatrix TMatrix,
239
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixGF,
240
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixHF
235
241
>
242
+ PBAT_HOST_DEVICE
243
+ typename TMatrix::ScalarType
236
244
{ energy_name } <{ d } >::evalWithGradAndHessian(
237
245
[[maybe_unused]] TMatrix const& F,
238
246
[[maybe_unused]] typename TMatrix::ScalarType mu,
239
- [[maybe_unused]] typename TMatrix::ScalarType lambda) const
247
+ [[maybe_unused]] typename TMatrix::ScalarType lambda,
248
+ TMatrixGF& gF,
249
+ TMatrixHF& HF) const
240
250
{{
251
+ static_assert(
252
+ TMatrixGF::kRows == { vecF .shape [0 ]} and TMatrixGF::kCols == 1,
253
+ "Grad w.r.t. F must have dimensions { vecF .shape [0 ]} x1");
254
+ static_assert(
255
+ TMatrixHF::kRows == { vecF .shape [0 ]} and TMatrixHF::kCols == { vecF .shape [0 ]} ,
256
+ "Hessian w.r.t. F must have dimensions { vecF .shape [0 ]} x{ vecF .shape [0 ]} ");
241
257
using ScalarType = typename TMatrix::ScalarType;
242
258
ScalarType psi;
243
- SVector<ScalarType, { vecF .shape [0 ]} > G;
244
- SMatrix<ScalarType, { vecF .shape [0 ]} ,{ vecF .shape [0 ]} > H;
245
259
{ cg .tabulate (evalgradhesspsi , spaces = 4 )}
246
- return {{ psi, G, H}} ;
260
+ return psi;
247
261
}}
248
262
249
- template <math::linalg::mini::CReadableVectorizedMatrix TMatrix>
250
- PBAT_HOST_DEVICE
251
- std::tuple<
252
- { energy_name } <{ d } >::SVector<typename TMatrix::ScalarType, { vecF .shape [0 ]} >,
253
- { energy_name } <{ d } >::SMatrix<typename TMatrix::ScalarType, { vecF .shape [0 ]} ,{ vecF .shape [0 ]} >
263
+ template <
264
+ math::linalg::mini::CReadableVectorizedMatrix TMatrix,
265
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixGF,
266
+ math::linalg::mini::CWriteableVectorizedMatrix TMatrixHF
254
267
>
268
+ PBAT_HOST_DEVICE
269
+ void
255
270
{ energy_name } <{ d } >::gradAndHessian(
256
271
[[maybe_unused]] TMatrix const& F,
257
272
[[maybe_unused]] typename TMatrix::ScalarType mu,
258
- [[maybe_unused]] typename TMatrix::ScalarType lambda) const
273
+ [[maybe_unused]] typename TMatrix::ScalarType lambda,
274
+ TMatrixGF& gF,
275
+ TMatrixHF& HF) const
259
276
{{
277
+ static_assert(
278
+ TMatrixGF::kRows == { vecF .shape [0 ]} and TMatrixGF::kCols == 1,
279
+ "Grad w.r.t. F must have dimensions { vecF .shape [0 ]} x1");
280
+ static_assert(
281
+ TMatrixHF::kRows == { vecF .shape [0 ]} and TMatrixHF::kCols == { vecF .shape [0 ]} ,
282
+ "Hessian w.r.t. F must have dimensions { vecF .shape [0 ]} x{ vecF .shape [0 ]} ");
260
283
using ScalarType = typename TMatrix::ScalarType;
261
- SVector<ScalarType, { vecF .shape [0 ]} > G;
262
- SMatrix<ScalarType, { vecF .shape [0 ]} ,{ vecF .shape [0 ]} > H;
263
284
{ cg .tabulate (gradhesspsi , spaces = 4 )}
264
- return {{G, H}};
265
285
}}
266
286
"""
267
287
source .append (impl )
0 commit comments