Skip to content

Commit 917fe3a

Browse files
committed
Use integer storage in SubMatrix types
1 parent de3adbb commit 917fe3a

File tree

3 files changed

+24
-25
lines changed

3 files changed

+24
-25
lines changed

source/pbat/math/linalg/mini/Matrix.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ PBAT_HOST_DEVICE auto Unit(auto i)
235235
return Identity<TScalar, M, M>().Col(i);
236236
}
237237

238-
template <int M, int N, class TScalar>
238+
template <auto M, auto N, class TScalar>
239239
PBAT_HOST_DEVICE auto FromFlatBuffer(TScalar* buf, std::int64_t bi)
240240
{
241241
return SMatrixView<TScalar, M, N>(buf + M * N * bi);
@@ -290,7 +290,7 @@ ToFlatBuffer(TMatrix const& A, TIndexMatrix const& inds, typename TMatrix::Scala
290290
}
291291
}
292292

293-
template <int M, int N, class TScalar>
293+
template <auto M, auto N, class TScalar>
294294
PBAT_HOST_DEVICE auto FromBuffers(std::array<TScalar*, M> buf, std::int64_t bi)
295295
{
296296
using ScalarType = std::remove_const_t<TScalar>;
@@ -300,7 +300,7 @@ PBAT_HOST_DEVICE auto FromBuffers(std::array<TScalar*, M> buf, std::int64_t bi)
300300
return A;
301301
}
302302

303-
template <int K, class TScalar, CMatrix TIndexMatrix>
303+
template <auto K, class TScalar, CMatrix TIndexMatrix>
304304
PBAT_HOST_DEVICE auto FromBuffers(std::array<TScalar*, K> buf, TIndexMatrix const& inds)
305305
{
306306
using IntegerType = typename TIndexMatrix::ScalarType;
@@ -315,7 +315,7 @@ PBAT_HOST_DEVICE auto FromBuffers(std::array<TScalar*, K> buf, TIndexMatrix cons
315315
return A;
316316
}
317317

318-
template <CMatrix TMatrix, int M>
318+
template <CMatrix TMatrix, auto M>
319319
PBAT_HOST_DEVICE void
320320
ToBuffers(TMatrix const& A, std::array<typename TMatrix::ScalarType*, M> buf, std::int64_t bi)
321321
{
@@ -325,7 +325,7 @@ ToBuffers(TMatrix const& A, std::array<typename TMatrix::ScalarType*, M> buf, st
325325
ForRange<0, M>([&]<auto i>() { FromFlatBuffer<1, N>(buf[i], bi) = A.Row(i); });
326326
}
327327

328-
template <CMatrix TMatrix, CMatrix TIndexMatrix, int K>
328+
template <CMatrix TMatrix, CMatrix TIndexMatrix, auto K>
329329
PBAT_HOST_DEVICE void ToBuffers(
330330
TMatrix const& A,
331331
TIndexMatrix const& inds,

source/pbat/math/linalg/mini/SubMatrix.h

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ class ConstSubMatrix
2626
static auto constexpr kCols = N;
2727
static bool constexpr bRowMajor = NestedType::bRowMajor;
2828

29-
PBAT_HOST_DEVICE ConstSubMatrix(NestedType const& A, auto ib = 0, auto jb = 0)
30-
: A(A), ib(ib), jb(jb)
29+
PBAT_HOST_DEVICE ConstSubMatrix(NestedType const& A, int ib, int jb) : A(A), ib(ib), jb(jb)
3130
{
3231
static_assert(
3332
NestedType::kRows >= M and NestedType::kCols >= N and M > 0 and N > 0,
@@ -78,7 +77,7 @@ class SubMatrix
7877
static auto constexpr kCols = N;
7978
static bool constexpr bRowMajor = NestedType::bRowMajor;
8079

81-
PBAT_HOST_DEVICE SubMatrix(NestedType& A, auto ib = 0, auto jb = 0) : A(A), ib(ib), jb(jb)
80+
PBAT_HOST_DEVICE SubMatrix(NestedType& A, int ib, int jb) : A(A), ib(ib), jb(jb)
8281
{
8382
static_assert(
8483
NestedType::kRows >= M and NestedType::kCols >= N and M > 0 and N > 0,
@@ -132,30 +131,30 @@ class SubMatrix
132131

133132
#define PBAT_MINI_SUBMATRIX_API(SelfType) \
134133
template <auto S, auto T> \
135-
PBAT_HOST_DEVICE auto Slice(auto i, auto j) \
134+
PBAT_HOST_DEVICE auto Slice(int i, int j) \
136135
{ \
137136
return SubMatrix<SelfType, S, T>(*this, i, j); \
138137
} \
139-
PBAT_HOST_DEVICE auto Col(auto j) \
138+
PBAT_HOST_DEVICE auto Col(int j) \
140139
{ \
141140
return Slice<kRows, 1>(0, j); \
142141
} \
143-
PBAT_HOST_DEVICE auto Row(auto i) \
142+
PBAT_HOST_DEVICE auto Row(int i) \
144143
{ \
145144
return Slice<1, kCols>(i, 0); \
146145
}
147146

148147
#define PBAT_MINI_CONST_SUBMATRIX_API(SelfType) \
149148
template <auto S, auto T> \
150-
PBAT_HOST_DEVICE auto Slice(auto i, auto j) const \
149+
PBAT_HOST_DEVICE auto Slice(int i, int j) const \
151150
{ \
152151
return ConstSubMatrix<SelfType, S, T>(*this, i, j); \
153152
} \
154-
PBAT_HOST_DEVICE auto Col(auto j) const \
153+
PBAT_HOST_DEVICE auto Col(int j) const \
155154
{ \
156155
return Slice<kRows, 1>(0, j); \
157156
} \
158-
PBAT_HOST_DEVICE auto Row(auto i) const \
157+
PBAT_HOST_DEVICE auto Row(int i) const \
159158
{ \
160159
return Slice<1, kCols>(i, 0); \
161160
}

source/pbat/math/linalg/mini/Transpose.h

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ConstTransposeSubMatrix
3838
static auto constexpr kCols = N;
3939
static bool constexpr bRowMajor = NestedType::bRowMajor;
4040

41-
PBAT_HOST_DEVICE ConstTransposeSubMatrix(NestedType const& A, auto ib = 0, auto jb = 0)
41+
PBAT_HOST_DEVICE ConstTransposeSubMatrix(NestedType const& A, int ib = 0, int jb = 0)
4242
: A(A), ib(ib), jb(jb)
4343
{
4444
static_assert(
@@ -56,12 +56,12 @@ class ConstTransposeSubMatrix
5656
PBAT_HOST_DEVICE ScalarType operator[](auto i) const { return (*this)(i); }
5757

5858
template <auto S, auto T>
59-
PBAT_HOST_DEVICE auto Slice(auto i, auto j) const
59+
PBAT_HOST_DEVICE auto Slice(int i, int j) const
6060
{
6161
return ConstTransposeSubMatrix<SelfType, S, T>(*this, i, j);
6262
}
63-
PBAT_HOST_DEVICE auto Col(auto j) const { return Slice<kRows, 1>(0, j); }
64-
PBAT_HOST_DEVICE auto Row(auto i) const { return Slice<1, kCols>(i, 0); }
63+
PBAT_HOST_DEVICE auto Col(int j) const { return Slice<kRows, 1>(0, j); }
64+
PBAT_HOST_DEVICE auto Row(int i) const { return Slice<1, kCols>(i, 0); }
6565
PBAT_HOST_DEVICE auto Transpose() const { return ConstTransposeView<SelfType>(*this); }
6666

6767
private:
@@ -81,7 +81,7 @@ class TransposeSubMatrix
8181
static auto constexpr kCols = N;
8282
static bool constexpr bRowMajor = NestedType::bRowMajor;
8383

84-
PBAT_HOST_DEVICE TransposeSubMatrix(NestedType& A, auto ib = 0, auto jb = 0)
84+
PBAT_HOST_DEVICE TransposeSubMatrix(NestedType& A, int ib = 0, int jb = 0)
8585
: A(A), ib(ib), jb(jb)
8686
{
8787
static_assert(
@@ -109,20 +109,20 @@ class TransposeSubMatrix
109109
PBAT_HOST_DEVICE ScalarType& operator[](auto i) { return (*this)(i); }
110110

111111
template <auto S, auto T>
112-
PBAT_HOST_DEVICE auto Slice(auto i, auto j)
112+
PBAT_HOST_DEVICE auto Slice(int i, int j)
113113
{
114114
return TransposeSubMatrix<SelfType, S, T>(*this, i, j);
115115
}
116-
PBAT_HOST_DEVICE auto Col(auto j) { return Slice<kRows, 1>(0, j); }
117-
PBAT_HOST_DEVICE auto Row(auto i) { return Slice<1, kCols>(i, 0); }
116+
PBAT_HOST_DEVICE auto Col(int j) { return Slice<kRows, 1>(0, j); }
117+
PBAT_HOST_DEVICE auto Row(int i) { return Slice<1, kCols>(i, 0); }
118118

119119
template <auto S, auto T>
120-
PBAT_HOST_DEVICE auto Slice(auto i, auto j) const
120+
PBAT_HOST_DEVICE auto Slice(int i, int j) const
121121
{
122122
return ConstTransposeSubMatrix<SelfType, S, T>(*this, i, j);
123123
}
124-
PBAT_HOST_DEVICE auto Col(auto j) const { return Slice<kRows, 1>(0, j); }
125-
PBAT_HOST_DEVICE auto Row(auto i) const { return Slice<1, kCols>(i, 0); }
124+
PBAT_HOST_DEVICE auto Col(int j) const { return Slice<kRows, 1>(0, j); }
125+
PBAT_HOST_DEVICE auto Row(int i) const { return Slice<1, kCols>(i, 0); }
126126

127127
PBAT_HOST_DEVICE auto Transpose() { return TransposeView<SelfType>(*this); }
128128
PBAT_HOST_DEVICE auto Transpose() const { return ConstTransposeView<SelfType>(*this); }

0 commit comments

Comments
 (0)