Skip to content

Commit 546fa4b

Browse files
committed
Implement convenience functions for retrieving/storing mini objects to global buffers
1 parent 57516f1 commit 546fa4b

File tree

8 files changed

+415
-9
lines changed

8 files changed

+415
-9
lines changed

source/pbat/gpu/math/linalg/Matrix.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@ namespace test {
2121
template <class Func>
2222
pbat::GpuMatrixX RunKernel(pbat::GpuMatrixX const& A)
2323
{
24-
using namespace pbat::gpu;
2524
auto const toGpu = [](GpuMatrixX const& A) {
26-
common::Buffer<GpuScalar> buf(A.size());
25+
gpu::common::Buffer<GpuScalar> buf(A.size());
2726
thrust::copy(A.data(), A.data() + A.size(), buf.Data());
2827
return buf;
2928
};
30-
auto const fromGpu = [](common::Buffer<GpuScalar> const& buf, auto rows, auto cols) {
29+
auto const fromGpu = [](gpu::common::Buffer<GpuScalar> const& buf, auto rows, auto cols) {
3130
GpuMatrixX A(rows, cols);
3231
thrust::copy(buf.Data(), buf.Data() + buf.Size(), A.data());
3332
return A;

source/pbat/math/linalg/mini/BinaryOperations.h

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,34 @@ class Sum
5151
RhsNestedType const& B;
5252
};
5353

54+
template <CMatrix TLhsMatrix>
55+
class SumScalar
56+
{
57+
public:
58+
using LhsNestedType = TLhsMatrix;
59+
60+
using ScalarType = typename LhsNestedType::ScalarType;
61+
using SelfType = SumScalar<LhsNestedType>;
62+
63+
static auto constexpr kRows = LhsNestedType::kRows;
64+
static auto constexpr kCols = LhsNestedType::kCols;
65+
static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
66+
67+
PBAT_HOST_DEVICE SumScalar(LhsNestedType const& A, ScalarType k) : mA(A), mK(k) {}
68+
69+
PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const { return mA(i, j) + mK; }
70+
71+
// Vector(ized) access
72+
PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
73+
PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
74+
75+
PBAT_MINI_READ_API(SelfType)
76+
77+
private:
78+
LhsNestedType const& mA;
79+
ScalarType mK;
80+
};
81+
5482
template <CMatrix TLhsMatrix, CMatrix TRhsMatrix>
5583
class Subtraction
5684
{
@@ -86,6 +114,33 @@ class Subtraction
86114
RhsNestedType const& B;
87115
};
88116

117+
template <CMatrix TLhsMatrix>
118+
class SubtractionScalar
119+
{
120+
public:
121+
using LhsNestedType = TLhsMatrix;
122+
using ScalarType = typename LhsNestedType::ScalarType;
123+
using SelfType = SubtractionScalar<LhsNestedType>;
124+
125+
static auto constexpr kRows = LhsNestedType::kRows;
126+
static auto constexpr kCols = LhsNestedType::kCols;
127+
static bool constexpr bRowMajor = LhsNestedType::bRowMajor;
128+
129+
PBAT_HOST_DEVICE SubtractionScalar(LhsNestedType const& A, ScalarType k) : mA(A), mK(k) {}
130+
131+
PBAT_HOST_DEVICE ScalarType operator()(auto i, auto j) const { return mA(i, j) - mK; }
132+
133+
// Vector(ized) access
134+
PBAT_HOST_DEVICE ScalarType operator()(auto i) const { return (*this)(i % kRows, i / kRows); }
135+
PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
136+
137+
PBAT_MINI_READ_API(SelfType)
138+
139+
private:
140+
LhsNestedType const& mA;
141+
ScalarType mK;
142+
};
143+
89144
template <CMatrix TLhsMatrix, CMatrix TRhsMatrix>
90145
class Minimum
91146
{
@@ -245,9 +300,16 @@ PBAT_HOST_DEVICE auto operator+(TLhsMatrix&& A, TRhsMatrix&& B)
245300
{
246301
using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
247302
using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
248-
return Sum<LhsMatrixType, RhsMatrixType>(
249-
std::forward<TLhsMatrix>(A),
250-
std::forward<TRhsMatrix>(B));
303+
if constexpr (std::is_arithmetic_v<RhsMatrixType>)
304+
{
305+
return SumScalar<LhsMatrixType>(std::forward<TLhsMatrix>(A), std::forward<TRhsMatrix>(B));
306+
}
307+
else
308+
{
309+
return Sum<LhsMatrixType, RhsMatrixType>(
310+
std::forward<TLhsMatrix>(A),
311+
std::forward<TRhsMatrix>(B));
312+
}
251313
}
252314

253315
template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>
@@ -262,9 +324,18 @@ PBAT_HOST_DEVICE auto operator-(TLhsMatrix&& A, TRhsMatrix&& B)
262324
{
263325
using LhsMatrixType = std::remove_cvref_t<TLhsMatrix>;
264326
using RhsMatrixType = std::remove_cvref_t<TRhsMatrix>;
265-
return Subtraction<LhsMatrixType, RhsMatrixType>(
266-
std::forward<TLhsMatrix>(A),
267-
std::forward<TRhsMatrix>(B));
327+
if constexpr (std::is_arithmetic_v<RhsMatrixType>)
328+
{
329+
return SubtractionScalar<LhsMatrixType>(
330+
std::forward<TLhsMatrix>(A),
331+
std::forward<TRhsMatrix>(B));
332+
}
333+
else
334+
{
335+
return Subtraction<LhsMatrixType, RhsMatrixType>(
336+
std::forward<TLhsMatrix>(A),
337+
std::forward<TRhsMatrix>(B));
338+
}
268339
}
269340

270341
template <class /*CMatrix*/ TLhsMatrix, class /*CMatrix*/ TRhsMatrix>

source/pbat/math/linalg/mini/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ target_sources(PhysicsBasedAnimationToolkit_PhysicsBasedAnimationToolkit
1717
"Reductions.h"
1818
"Repeat.h"
1919
"Scale.h"
20+
"Stack.h"
2021
"SubMatrix.h"
2122
"Transpose.h"
2223
"UnaryOperations.h"
@@ -35,6 +36,7 @@ target_sources(PhysicsBasedAnimationToolkit_PhysicsBasedAnimationToolkit
3536
"Reductions.cpp"
3637
"Repeat.cpp"
3738
"Scale.cpp"
39+
"Stack.cpp"
3840
"SubMatrix.cpp"
3941
"Transpose.cpp"
4042
"UnaryOperations.cpp"

source/pbat/math/linalg/mini/Matrix.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#include "Matrix.h"
22

3+
#include "BinaryOperations.h"
4+
#include "Reductions.h"
5+
#include "Repeat.h"
6+
#include "Stack.h"
37
#include "pbat/Aliases.h"
48

59
#include <doctest/doctest.h>
@@ -25,4 +29,61 @@ TEST_CASE("[math][linalg][mini] Matrix")
2529
using MatrixViewType = SMatrixView<ScalarType, kRows, kCols>;
2630
PBAT_MINI_CHECK_READABLE_CONCEPTS(MatrixViewType);
2731
PBAT_MINI_CHECK_WRITEABLE_CONCEPTS(MatrixViewType);
32+
33+
pbat::MatrixX data(2, 10);
34+
data.leftCols(5).array() = ScalarType(1);
35+
data.rightCols(5).array() = ScalarType(2);
36+
auto flatBufViewLeft = FromFlatBuffer<2, 5>(data.data(), 0);
37+
auto flatBufViewRight = FromFlatBuffer<2, 5>(data.data(), 1);
38+
CHECK(All(flatBufViewLeft == ScalarType(1)));
39+
CHECK(All(flatBufViewRight == ScalarType(2)));
40+
41+
pbat::VectorX dataTop = data.row(0);
42+
pbat::VectorX dataBottom = data.row(1);
43+
std::array<ScalarType*, 2> bufs{dataTop.data(), dataBottom.data()};
44+
auto bufViewLeft = FromBuffers<2, 5>(bufs, 0);
45+
auto bufViewRight = FromBuffers<2, 5>(bufs, 1);
46+
CHECK(All(flatBufViewLeft == bufViewLeft));
47+
CHECK(All(flatBufViewRight == bufViewRight));
48+
ToBuffers(bufViewLeft + Ones<ScalarType, 2, 5>(), bufs, 0);
49+
ToBuffers(bufViewRight + Ones<ScalarType, 2, 5>(), bufs, 1);
50+
for (auto i = 0; i < 5; ++i)
51+
{
52+
CHECK_EQ(dataTop(i), ScalarType(2));
53+
CHECK_EQ(dataBottom(i), ScalarType(2));
54+
}
55+
for (auto i = 5; i < 10; ++i)
56+
{
57+
CHECK_EQ(dataTop(i), ScalarType(3));
58+
CHECK_EQ(dataBottom(i), ScalarType(3));
59+
}
60+
61+
ToFlatBuffer(flatBufViewLeft + Ones<ScalarType, 2, 5>(), data.data(), 0);
62+
ToFlatBuffer(flatBufViewRight + Ones<ScalarType, 2, 5>(), data.data(), 1);
63+
for (auto i = 0; i < 5; ++i)
64+
{
65+
CHECK_EQ(data(0, i), ScalarType(2));
66+
CHECK_EQ(data(1, i), ScalarType(2));
67+
}
68+
for (auto i = 5; i < 10; ++i)
69+
{
70+
CHECK_EQ(data(0, i), ScalarType(3));
71+
CHECK_EQ(data(1, i), ScalarType(3));
72+
}
73+
74+
using IndexType = int;
75+
SMatrix<IndexType, 1, 5> inds{0, 2, 4, 6, 8};
76+
ToFlatBuffer(Ones<ScalarType, 2, 5>(), inds, data.data());
77+
for (auto i = 0; i < 10; i += 2)
78+
{
79+
CHECK_EQ(data(0, i), ScalarType(1));
80+
CHECK_EQ(data(1, i), ScalarType(1));
81+
}
82+
83+
ToBuffers(Ones<ScalarType, 2, 5>(), inds, bufs);
84+
for (auto i = 0; i < 10; i += 2)
85+
{
86+
CHECK_EQ(dataTop(i), ScalarType(1));
87+
CHECK_EQ(dataBottom(i), ScalarType(1));
88+
}
2889
}

source/pbat/math/linalg/mini/Matrix.h

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "Api.h"
55
#include "Concepts.h"
66
#include "pbat/HostDevice.h"
7+
#include "pbat/common/ConstexprFor.h"
78

89
#include <array>
910
#include <initializer_list>
@@ -233,6 +234,111 @@ PBAT_HOST_DEVICE auto Unit(auto i)
233234
return Identity<TScalar, M, M>().Col(i);
234235
}
235236

237+
template <int M, int N, class TScalar, class IndexType>
238+
PBAT_HOST_DEVICE auto FromFlatBuffer(TScalar* buf, IndexType bi)
239+
{
240+
return SMatrixView<TScalar, M, N>(buf + M * N * bi);
241+
}
242+
243+
template <class TScalar, CMatrix TIndexMatrix>
244+
PBAT_HOST_DEVICE auto FromFlatBuffer(TScalar* buf, TIndexMatrix const& inds)
245+
{
246+
using IntegerType = typename TIndexMatrix::ScalarType;
247+
static_assert(std::is_integral_v<IntegerType>, "inds must be matrix of indices");
248+
auto constexpr M = TIndexMatrix::kRows;
249+
auto constexpr N = TIndexMatrix::kCols;
250+
SMatrix<std::remove_const_t<TScalar>, M, N> A{};
251+
using pbat::common::ForRange;
252+
ForRange<0, N>([&]<auto j>() { ForRange<0, M>([&]<auto i>() { A(i, j) = buf[inds(i, j)]; }); });
253+
return A;
254+
}
255+
256+
template <CMatrix TMatrix, class IndexType>
257+
PBAT_HOST_DEVICE void
258+
ToFlatBuffer(TMatrix const& A, typename TMatrix::ScalarType* buf, IndexType bi)
259+
{
260+
auto constexpr M = TMatrix::kRows;
261+
auto constexpr N = TMatrix::kCols;
262+
FromFlatBuffer<M, N>(buf, bi) = A;
263+
}
264+
265+
template <CMatrix TMatrix, CMatrix TIndexMatrix>
266+
PBAT_HOST_DEVICE void
267+
ToFlatBuffer(TMatrix const& A, TIndexMatrix const& inds, typename TMatrix::ScalarType* buf)
268+
{
269+
auto constexpr MA = TMatrix::kRows;
270+
auto constexpr NA = TMatrix::kCols;
271+
auto constexpr MI = TIndexMatrix::kRows;
272+
auto constexpr NI = TIndexMatrix::kCols;
273+
static_assert(MA == MI or MI == 1, "A must have same rows as inds or inds is a row vector");
274+
static_assert(NA == NI, "A must have same cols as inds");
275+
using pbat::common::ForRange;
276+
if constexpr (MA > 1 and MI == 1)
277+
{
278+
// In this case, I will assume that the user wishes to put each column of A in the
279+
// corresponding "column" in the flat buffer buf, as if column major, according to inds.
280+
ForRange<0, NA>([&]<auto j>() {
281+
ForRange<0, MA>([&]<auto i>() { buf[MA * inds(0, j) + i] = A(i, j); });
282+
});
283+
}
284+
else
285+
{
286+
ForRange<0, NA>(
287+
[&]<auto j>() { ForRange<0, MA>([&]<auto i>() { buf[inds(i, j)] = A(i, j); }); });
288+
}
289+
}
290+
291+
template <int M, int N, class TScalar, class IndexType>
292+
PBAT_HOST_DEVICE auto
293+
FromBuffers([[maybe_unused]] std::array<TScalar*, M> buf, [[maybe_unused]] IndexType bi)
294+
{
295+
using ScalarType = std::remove_const_t<TScalar>;
296+
SMatrix<ScalarType, M, N> A{};
297+
using pbat::common::ForRange;
298+
ForRange<0, M>([&]<auto i>() { A.Row(i) = FromFlatBuffer<1, N>(buf[i], bi); });
299+
return A;
300+
}
301+
302+
template <int K, class TScalar, CMatrix TIndexMatrix>
303+
PBAT_HOST_DEVICE auto FromBuffers(std::array<TScalar*, K> buf, TIndexMatrix const& inds)
304+
{
305+
using IntegerType = typename TIndexMatrix::ScalarType;
306+
static_assert(std::is_integral_v<IntegerType>, "inds must be matrix of indices");
307+
auto constexpr M = TIndexMatrix::kRows;
308+
auto constexpr N = TIndexMatrix::kCols;
309+
SMatrix<std::remove_const_t<TScalar>, K * M, N> A{};
310+
using pbat::common::ForRange;
311+
ForRange<0, K>([&]<auto k>() { A.Slice<M, N>(k * M, 0) = FromFlatBuffer(buf[k], inds); });
312+
return A;
313+
}
314+
315+
template <CMatrix TMatrix, int M, class IndexType>
316+
PBAT_HOST_DEVICE void
317+
ToBuffers(TMatrix const& A, std::array<typename TMatrix::ScalarType*, M> buf, IndexType bi)
318+
{
319+
static_assert(M == TMatrix::kRows, "A must have same rows as number of buffers");
320+
auto constexpr N = TMatrix::kCols;
321+
using pbat::common::ForRange;
322+
ForRange<0, M>([&]<auto i>() { FromFlatBuffer<1, N>(buf[i], bi) = A.Row(i); });
323+
}
324+
325+
template <CMatrix TMatrix, CMatrix TIndexMatrix, int K>
326+
PBAT_HOST_DEVICE void ToBuffers(
327+
TMatrix const& A,
328+
TIndexMatrix const& inds,
329+
std::array<typename TMatrix::ScalarType*, K> buf)
330+
{
331+
auto constexpr MA = TMatrix::kRows;
332+
auto constexpr NA = TMatrix::kCols;
333+
auto constexpr MI = TIndexMatrix::kRows;
334+
auto constexpr NI = TIndexMatrix::kCols;
335+
static_assert(MA % MI == 0, "Rows of A must be multiple of rows of inds");
336+
static_assert(NA == NI, "A and inds must have same number of columns");
337+
static_assert(MA / MI == K, "A must have number of rows == #buffers*#rows of inds");
338+
using pbat::common::ForRange;
339+
ForRange<0, K>([&]<auto k>() { ToFlatBuffer(A.Slice<MI, NI>(k * MI, 0), inds, buf[k]); });
340+
}
341+
236342
} // namespace mini
237343
} // namespace linalg
238344
} // namespace math

source/pbat/math/linalg/mini/Mini.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "Product.h"
1414
#include "Reductions.h"
1515
#include "Repeat.h"
16+
#include "Scale.h"
17+
#include "Stack.h"
1618
#include "SubMatrix.h"
1719
#include "Transpose.h"
1820
#include "UnaryOperations.h"
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include "Stack.h"
2+
3+
#include "Matrix.h"
4+
5+
#include <doctest/doctest.h>
6+
7+
TEST_CASE("[math][linalg][mini] Stack")
8+
{
9+
using namespace pbat::math::linalg::mini;
10+
auto constexpr kRows = 2;
11+
auto constexpr kCols = 3;
12+
using ScalarType = double;
13+
using MatrixType = SMatrix<ScalarType, kRows, kCols>;
14+
MatrixType A{};
15+
A.SetConstant(ScalarType(1));
16+
MatrixType B{};
17+
B.SetConstant(ScalarType(2));
18+
auto hstack = HStack(A, B);
19+
CHECK_EQ(hstack.Rows(), A.Rows());
20+
CHECK_EQ(hstack.Cols(), A.Cols() + B.Cols());
21+
for (auto i = 0; i < kRows; ++i)
22+
{
23+
for (auto j = 0; j < kCols; ++j)
24+
{
25+
CHECK_EQ(hstack(i, j), ScalarType(1));
26+
}
27+
for (auto j = kCols; j < 2 * kCols; ++j)
28+
{
29+
CHECK_EQ(hstack(i, j), ScalarType(2));
30+
}
31+
}
32+
auto vstack = VStack(A, B);
33+
CHECK_EQ(vstack.Cols(), A.Cols());
34+
CHECK_EQ(vstack.Rows(), A.Rows() + B.Rows());
35+
for (auto j = 0; j < kCols; ++j)
36+
{
37+
for (auto i = 0; i < kRows; ++i)
38+
{
39+
CHECK_EQ(vstack(i, j), ScalarType(1));
40+
}
41+
for (auto i = kRows; i < 2 * kRows; ++i)
42+
{
43+
CHECK_EQ(vstack(i, j), ScalarType(2));
44+
}
45+
}
46+
}

0 commit comments

Comments
 (0)