|
6 | 6 | #define EIGEN_MPL2_ONLY
|
7 | 7 | #include <Eigen/Dense>
|
8 | 8 | #include "Common.h"
|
| 9 | +#include "Vector.h" |
9 | 10 |
|
10 | 11 | namespace isce3 { namespace core {
|
11 | 12 |
|
12 | 13 | template<int N, typename T>
|
13 | 14 | class DenseMatrix : public Eigen::Matrix<T, N, N> {
|
14 | 15 | using super_t = Eigen::Matrix<T, N, N>;
|
15 | 16 | using super_t::super_t;
|
| 17 | + |
| 18 | + static_assert(N > 0); |
16 | 19 | public:
|
17 | 20 | DenseMatrix() = default;
|
18 | 21 | CUDA_HOSTDEV auto operator[](int i) { return this->row(i); }
|
19 | 22 | CUDA_HOSTDEV auto operator[](int i) const { return this->row(i); }
|
20 | 23 |
|
21 |
| - CUDA_HOSTDEV auto dot(const super_t& other) const { |
| 24 | + CUDA_HOSTDEV auto dot(const DenseMatrix& other) const { |
22 | 25 | return *this * other;
|
23 | 26 | }
|
24 | 27 |
|
25 |
| - CUDA_HOSTDEV auto dot(const Eigen::Matrix<T, N, 1>& other) const { |
| 28 | + CUDA_HOSTDEV auto dot(const Vector<N, T>& other) const { |
26 | 29 | return *this * other;
|
27 | 30 | }
|
28 | 31 |
|
@@ -78,4 +81,33 @@ CUDA_HOSTDEV Mat3 DenseMatrix<N, T>::enuToXyz(double lat, double lon)
|
78 | 81 | {0, cos(lat), sin(lat)}}};
|
79 | 82 | }
|
80 | 83 |
|
| 84 | +// XXX |
| 85 | +// These overloads are a workaround to resolve an issue observed with certain |
| 86 | +// Eigen & CUDA version combinations where matrix-matrix and matrix-vector |
| 87 | +// multiplication produced incorrect results (or raised "illegal memory access" |
| 88 | +// errors in debug mode). |
| 89 | +template<int N, typename T> |
| 90 | +CUDA_HOSTDEV auto |
| 91 | +operator*(const DenseMatrix<N, T>& a, const DenseMatrix<N, T>& b) |
| 92 | +{ |
| 93 | + DenseMatrix<N, T> out; |
| 94 | + for (int i = 0; i < N; ++i) { |
| 95 | + for (int j = 0; j < N; ++j) { |
| 96 | + out(i, j) = a.row(i).dot(b.col(j)); |
| 97 | + } |
| 98 | + } |
| 99 | + return out; |
| 100 | +} |
| 101 | + |
| 102 | +template<int N, typename T> |
| 103 | +CUDA_HOSTDEV auto |
| 104 | +operator*(const DenseMatrix<N, T>& m, const Vector<N, T>& v) |
| 105 | +{ |
| 106 | + Vector<N, T> out; |
| 107 | + for (int i = 0; i < N; ++i) { |
| 108 | + out[i] = m.row(i).dot(v); |
| 109 | + } |
| 110 | + return out; |
| 111 | +} |
| 112 | + |
81 | 113 | }}
|
0 commit comments