|
4 | 4 | #include "Api.h"
|
5 | 5 | #include "Concepts.h"
|
6 | 6 | #include "pbat/HostDevice.h"
|
| 7 | +#include "pbat/common/ConstexprFor.h" |
7 | 8 |
|
8 | 9 | #include <array>
|
9 | 10 | #include <initializer_list>
|
@@ -233,6 +234,111 @@ PBAT_HOST_DEVICE auto Unit(auto i)
|
233 | 234 | return Identity<TScalar, M, M>().Col(i);
|
234 | 235 | }
|
235 | 236 |
|
| 237 | +template <int M, int N, class TScalar, class IndexType> |
| 238 | +PBAT_HOST_DEVICE auto FromFlatBuffer(TScalar* buf, IndexType bi) |
| 239 | +{ |
| 240 | + return SMatrixView<TScalar, M, N>(buf + M * N * bi); |
| 241 | +} |
| 242 | + |
| 243 | +template <class TScalar, CMatrix TIndexMatrix> |
| 244 | +PBAT_HOST_DEVICE auto FromFlatBuffer(TScalar* buf, TIndexMatrix const& inds) |
| 245 | +{ |
| 246 | + using IntegerType = typename TIndexMatrix::ScalarType; |
| 247 | + static_assert(std::is_integral_v<IntegerType>, "inds must be matrix of indices"); |
| 248 | + auto constexpr M = TIndexMatrix::kRows; |
| 249 | + auto constexpr N = TIndexMatrix::kCols; |
| 250 | + SMatrix<std::remove_const_t<TScalar>, M, N> A{}; |
| 251 | + using pbat::common::ForRange; |
| 252 | + ForRange<0, N>([&]<auto j>() { ForRange<0, M>([&]<auto i>() { A(i, j) = buf[inds(i, j)]; }); }); |
| 253 | + return A; |
| 254 | +} |
| 255 | + |
| 256 | +template <CMatrix TMatrix, class IndexType> |
| 257 | +PBAT_HOST_DEVICE void |
| 258 | +ToFlatBuffer(TMatrix const& A, typename TMatrix::ScalarType* buf, IndexType bi) |
| 259 | +{ |
| 260 | + auto constexpr M = TMatrix::kRows; |
| 261 | + auto constexpr N = TMatrix::kCols; |
| 262 | + FromFlatBuffer<M, N>(buf, bi) = A; |
| 263 | +} |
| 264 | + |
| 265 | +template <CMatrix TMatrix, CMatrix TIndexMatrix> |
| 266 | +PBAT_HOST_DEVICE void |
| 267 | +ToFlatBuffer(TMatrix const& A, TIndexMatrix const& inds, typename TMatrix::ScalarType* buf) |
| 268 | +{ |
| 269 | + auto constexpr MA = TMatrix::kRows; |
| 270 | + auto constexpr NA = TMatrix::kCols; |
| 271 | + auto constexpr MI = TIndexMatrix::kRows; |
| 272 | + auto constexpr NI = TIndexMatrix::kCols; |
| 273 | + static_assert(MA == MI or MI == 1, "A must have same rows as inds or inds is a row vector"); |
| 274 | + static_assert(NA == NI, "A must have same cols as inds"); |
| 275 | + using pbat::common::ForRange; |
| 276 | + if constexpr (MA > 1 and MI == 1) |
| 277 | + { |
| 278 | + // In this case, I will assume that the user wishes to put each column of A in the |
| 279 | + // corresponding "column" in the flat buffer buf, as if column major, according to inds. |
| 280 | + ForRange<0, NA>([&]<auto j>() { |
| 281 | + ForRange<0, MA>([&]<auto i>() { buf[MA * inds(0, j) + i] = A(i, j); }); |
| 282 | + }); |
| 283 | + } |
| 284 | + else |
| 285 | + { |
| 286 | + ForRange<0, NA>( |
| 287 | + [&]<auto j>() { ForRange<0, MA>([&]<auto i>() { buf[inds(i, j)] = A(i, j); }); }); |
| 288 | + } |
| 289 | +} |
| 290 | + |
| 291 | +template <int M, int N, class TScalar, class IndexType> |
| 292 | +PBAT_HOST_DEVICE auto |
| 293 | +FromBuffers([[maybe_unused]] std::array<TScalar*, M> buf, [[maybe_unused]] IndexType bi) |
| 294 | +{ |
| 295 | + using ScalarType = std::remove_const_t<TScalar>; |
| 296 | + SMatrix<ScalarType, M, N> A{}; |
| 297 | + using pbat::common::ForRange; |
| 298 | + ForRange<0, M>([&]<auto i>() { A.Row(i) = FromFlatBuffer<1, N>(buf[i], bi); }); |
| 299 | + return A; |
| 300 | +} |
| 301 | + |
| 302 | +template <int K, class TScalar, CMatrix TIndexMatrix> |
| 303 | +PBAT_HOST_DEVICE auto FromBuffers(std::array<TScalar*, K> buf, TIndexMatrix const& inds) |
| 304 | +{ |
| 305 | + using IntegerType = typename TIndexMatrix::ScalarType; |
| 306 | + static_assert(std::is_integral_v<IntegerType>, "inds must be matrix of indices"); |
| 307 | + auto constexpr M = TIndexMatrix::kRows; |
| 308 | + auto constexpr N = TIndexMatrix::kCols; |
| 309 | + SMatrix<std::remove_const_t<TScalar>, K * M, N> A{}; |
| 310 | + using pbat::common::ForRange; |
| 311 | + ForRange<0, K>([&]<auto k>() { A.Slice<M, N>(k * M, 0) = FromFlatBuffer(buf[k], inds); }); |
| 312 | + return A; |
| 313 | +} |
| 314 | + |
| 315 | +template <CMatrix TMatrix, int M, class IndexType> |
| 316 | +PBAT_HOST_DEVICE void |
| 317 | +ToBuffers(TMatrix const& A, std::array<typename TMatrix::ScalarType*, M> buf, IndexType bi) |
| 318 | +{ |
| 319 | + static_assert(M == TMatrix::kRows, "A must have same rows as number of buffers"); |
| 320 | + auto constexpr N = TMatrix::kCols; |
| 321 | + using pbat::common::ForRange; |
| 322 | + ForRange<0, M>([&]<auto i>() { FromFlatBuffer<1, N>(buf[i], bi) = A.Row(i); }); |
| 323 | +} |
| 324 | + |
| 325 | +template <CMatrix TMatrix, CMatrix TIndexMatrix, int K> |
| 326 | +PBAT_HOST_DEVICE void ToBuffers( |
| 327 | + TMatrix const& A, |
| 328 | + TIndexMatrix const& inds, |
| 329 | + std::array<typename TMatrix::ScalarType*, K> buf) |
| 330 | +{ |
| 331 | + auto constexpr MA = TMatrix::kRows; |
| 332 | + auto constexpr NA = TMatrix::kCols; |
| 333 | + auto constexpr MI = TIndexMatrix::kRows; |
| 334 | + auto constexpr NI = TIndexMatrix::kCols; |
| 335 | + static_assert(MA % MI == 0, "Rows of A must be multiple of rows of inds"); |
| 336 | + static_assert(NA == NI, "A and inds must have same number of columns"); |
| 337 | + static_assert(MA / MI == K, "A must have number of rows == #buffers*#rows of inds"); |
| 338 | + using pbat::common::ForRange; |
| 339 | + ForRange<0, K>([&]<auto k>() { ToFlatBuffer(A.Slice<MI, NI>(k * MI, 0), inds, buf[k]); }); |
| 340 | +} |
| 341 | + |
236 | 342 | } // namespace mini
|
237 | 343 | } // namespace linalg
|
238 | 344 | } // namespace math
|
|
0 commit comments