Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions llvm/include/llvm/ADT/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,25 @@ class MatrixStorage {

size_t size() const { return Base.size(); }
bool empty() const { return !size(); }
size_t getNumRows() const { return size() / NCols; }
size_t getNumRows() const {
assert(size() % NCols == 0 && "Internal error");
return size() / NCols;
}
size_t getNumCols() const { return NCols; }
void setNumCols(size_t NCols) {
assert(empty() && "Column-resizing a non-empty MatrixStorage");
this->NCols = NCols;
}
void resize(size_t NRows) { Base.resize(NCols * NRows); }
void reserve(size_t NRows) { Base.reserve(NCols * NRows); }

protected:
template <typename U, size_t M, size_t NStorageInline>
friend class JaggedArrayView;

T *begin() const { return Base.begin(); }
T *rowFromIdx(size_t RowIdx, size_t Offset = 0) const {
assert(Offset < NCols && "Internal error");
return begin() + RowIdx * NCols + Offset;
}
std::pair<size_t, size_t> idxFromRow(T *Ptr) const {
Expand All @@ -87,6 +92,11 @@ class MatrixStorage {
Base.append(Diff, T());
}

void eraseLastRow() {
assert(getNumRows() > 0 && "Non-empty MatrixStorage expected");
Base.pop_back_n(NCols);
}

private:
MatrixStorageBase<T, N> Base;
size_t NCols;
Expand All @@ -108,10 +118,12 @@ struct [[nodiscard]] MutableRowView : public MutableArrayRef<T> {
: MutableArrayRef<T>(Begin, End) {}
MutableRowView(MutableArrayRef<T> Other)
: MutableArrayRef<T>(Other.data(), Other.size()) {}
MutableRowView(const SmallVectorImpl<T> &Vec) : MutableArrayRef<T>(Vec) {}
MutableRowView(SmallVectorImpl<T> &Vec) : MutableArrayRef<T>(Vec) {}

using MutableArrayRef<T>::size;
using MutableArrayRef<T>::data;
using MutableArrayRef<T>::begin;
using MutableArrayRef<T>::end;

T &back() const { return MutableArrayRef<T>::back(); }
T &front() const { return MutableArrayRef<T>::front(); }
Expand Down Expand Up @@ -148,6 +160,13 @@ struct [[nodiscard]] MutableRowView : public MutableArrayRef<T> {
std::swap(this->Length, Other.Length);
}

// For better cache behavior.
void writing_swap(MutableRowView<T> &Other) { // NOLINT
SmallVector<T> Buf{Other};
Other.copy_assign(begin(), end());
copy_assign(Buf.begin(), Buf.end());
}

protected:
void copy_assign(iterator Begin, iterator End) { // NOLINT
std::uninitialized_copy(Begin, End, data());
Expand Down Expand Up @@ -317,6 +336,14 @@ class [[nodiscard]] JaggedArrayView {
RowView.pop_back();
}

// For better cache behavior. To be used with writing_swap.
void eraseLastRow() {
assert(Mat.idxFromRow(lastRow().data()).first == Mat.getNumRows() - 1 &&
"Last row does not correspond to last row in storage");
dropLastRow();
Mat.eraseLastRow();
}

protected:
// Helper constructor.
constexpr JaggedArrayView(MatrixStorage<T, NStorageInline> &Mat,
Expand All @@ -329,11 +356,4 @@ class [[nodiscard]] JaggedArrayView {
};
} // namespace llvm

namespace std {
template <typename T>
inline void swap(llvm::MutableRowView<T> &LHS, llvm::MutableRowView<T> &RHS) {
LHS.swap(RHS);
}
} // end namespace std

#endif
58 changes: 37 additions & 21 deletions llvm/unittests/ADT/MatrixTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,31 +161,12 @@ TYPED_TEST(MatrixTest, RowColSlice) {
EXPECT_EQ(W[1][1], 0);
}

TYPED_TEST(MatrixTest, LastRowOps) {
TYPED_TEST(MatrixTest, NonWritingSwap) {
auto &M = this->SmallMatrix;
JaggedArrayView<TypeParam> V{M};
V[0] = {TypeParam(3), TypeParam(7)};
V[1] = {TypeParam(4), TypeParam(5)};
V.dropLastRow();
ASSERT_EQ(M.size(), 4u);
ASSERT_EQ(V.size(), 2u);
auto W = V.lastRow();
ASSERT_EQ(W.size(), 2u);
EXPECT_EQ(W[0], 3);
EXPECT_EQ(W[1], 7);
V.lastRow() = {TypeParam(1), TypeParam(2)};
EXPECT_EQ(V.lastRow()[0], 1);
EXPECT_EQ(V.lastRow()[1], 2);
V.dropLastRow();
EXPECT_TRUE(V.empty());
}

TYPED_TEST(MatrixTest, Swap) {
auto &M = this->SmallMatrix;
JaggedArrayView<TypeParam> V{M};
V[0] = {TypeParam(3), TypeParam(7)};
V[1] = {TypeParam(4), TypeParam(5)};
std::swap(V[0], V[1]);
V[0].swap(V[1]);
EXPECT_EQ(V.lastRow()[0], 3);
EXPECT_EQ(V.lastRow()[1], 7);
EXPECT_EQ(V[0][0], 4);
Expand Down Expand Up @@ -241,6 +222,41 @@ TYPED_TEST(MatrixTest, DropLastRow) {
EXPECT_EQ(V.lastRow()[2], 23);
}

TYPED_TEST(MatrixTest, EraseLastRow) {
auto &M = this->SmallMatrix;
JaggedArrayView<TypeParam> V{M};
V[0] = {TypeParam(3), TypeParam(7)};
V[1] = {TypeParam(4), TypeParam(5)};
V.eraseLastRow();
ASSERT_EQ(M.size(), 2u);
ASSERT_EQ(V.getRowSpan(), 1u);
auto W = V.lastRow();
ASSERT_EQ(W.size(), 2u);
EXPECT_EQ(W[0], 3);
EXPECT_EQ(W[1], 7);
V.addRow({TypeParam(1), TypeParam(2)});
ASSERT_EQ(V.getRowSpan(), 2u);
V[0].writing_swap(V[1]);
EXPECT_EQ(V[0][0], 1);
EXPECT_EQ(V[0][1], 2);
EXPECT_EQ(V.lastRow()[0], 3);
EXPECT_EQ(V.lastRow()[1], 7);
V.eraseLastRow();
V.addRow({TypeParam(3), TypeParam(7)});
ASSERT_EQ(V.getRowSpan(), 2u);
EXPECT_EQ(V[0][0], 1);
EXPECT_EQ(V[0][1], 2);
EXPECT_EQ(V[1][0], 3);
EXPECT_EQ(V[1][1], 7);
V.eraseLastRow();
ASSERT_EQ(V.getRowSpan(), 1u);
EXPECT_EQ(V[0][0], 1);
EXPECT_EQ(V[0][1], 2);
V.eraseLastRow();
EXPECT_TRUE(V.empty());
EXPECT_TRUE(M.empty());
}

TYPED_TEST(MatrixTest, Iteration) {
auto &M = this->SmallMatrix;
M.resize(2);
Expand Down