Skip to content

Commit b6d145c

Browse files
committed
Unified Validator and Visitor
1 parent deb3686 commit b6d145c

File tree

9 files changed

+587
-300
lines changed

9 files changed

+587
-300
lines changed

cpp/src/arrow/sparse_tensor.cc

Lines changed: 357 additions & 21 deletions
Large diffs are not rendered by default.

cpp/src/arrow/sparse_tensor.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,10 @@ class ARROW_EXPORT SparseTensor {
508508
return ToTensor(default_memory_pool());
509509
}
510510

511+
/// \brief Check whether the sparse tensor is valid and is the
512+
/// correct compressed form of the given tensor.
513+
Status Validate(const Tensor& tensor) const;
514+
511515
protected:
512516
// Constructor with all attributes
513517
SparseTensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
@@ -588,6 +592,8 @@ class SparseTensorImpl : public SparseTensor {
588592
ARROW_RETURN_NOT_OK(internal::MakeSparseTensorFromTensor(
589593
tensor, SparseIndexType::format_id, index_value_type, pool, &sparse_index,
590594
&data));
595+
// TODO: Check the created sparse tensor here (e.g. via Validate(tensor)).
596+
591597
return std::make_shared<SparseTensorImpl<SparseIndexType>>(
592598
internal::checked_pointer_cast<SparseIndexType>(sparse_index), tensor.type(),
593599
data, tensor.shape(), tensor.dim_names_);

cpp/src/arrow/sparse_tensor_test.cc

Lines changed: 108 additions & 5 deletions
Large diffs are not rendered by default.

cpp/src/arrow/tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ int64_t StridedTensorCountNonZero(int dim_index, int64_t offset, const Tensor& t
490490
if (dim_index == tensor.ndim() - 1) {
491491
for (int64_t i = 0; i < tensor.shape()[dim_index]; ++i) {
492492
const auto* ptr = tensor.raw_data() + offset + i * tensor.strides()[dim_index];
493-
auto& elem = *reinterpret_cast<const c_type*>(ptr);
493+
auto elem = *reinterpret_cast<const c_type*>(ptr);
494494
if (internal::is_not_zero<TYPE>(elem)) {
495495
++nnz;
496496
}

cpp/src/arrow/tensor/converter_internal.h

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424

2525
namespace arrow {
2626

27-
template <typename VISITOR, typename... ARGS>
28-
Status VisitTypeInline(const DataType& type, VISITOR* visitor, ARGS&&... args);
29-
3027
namespace internal {
3128

3229
struct SparseTensorConverterMixin {
@@ -71,52 +68,14 @@ Result<std::shared_ptr<Tensor>> MakeTensorFromSparseCSFTensor(
7168
template <typename Converter>
7269
struct ConverterVisitor {
7370
explicit ConverterVisitor(Converter& converter) : converter(converter) {}
74-
template <typename ValueType, typename IndexType>
75-
Status operator()(const ValueType& value, const IndexType& index_type) {
76-
return converter.Convert(value, index_type);
77-
}
7871

79-
Converter& converter;
80-
};
81-
82-
struct ValueTypeVisitor {
83-
template <typename ValueType, typename IndexType, typename Function>
84-
enable_if_number<ValueType, Status> Visit(const ValueType& value_type,
85-
const IndexType& index_type,
86-
Function&& function) {
87-
return function(value_type, index_type);
72+
template <typename... Args>
73+
Status operator()(Args&&... args) {
74+
return converter.Convert(std::forward<Args>(args)...);
8875
}
8976

90-
template <typename IndexType, typename Function>
91-
Status Visit(const DataType& value_type, const IndexType&, Function&&) {
92-
return Status::Invalid("Invalid value type: ", value_type.name(),
93-
". Expected a number.");
94-
}
95-
};
96-
97-
struct IndexAndValueTypeVisitor {
98-
template <typename IndexType, typename Function>
99-
enable_if_integer<IndexType, Status> Visit(const IndexType& index_type,
100-
const DataType& value_type,
101-
Function&& function) {
102-
ValueTypeVisitor visitor;
103-
return VisitTypeInline(value_type, &visitor, index_type,
104-
std::forward<Function>(function));
105-
}
106-
107-
template <typename Function>
108-
Status Visit(const DataType& type, const DataType&, Function&&) {
109-
return Status::Invalid("Invalid index type: ", type.name(), ". Expected integer.");
110-
}
77+
Converter& converter;
11178
};
11279

113-
template <typename Function>
114-
Status VisitValueAndIndexType(const DataType& value_type, const DataType& index_type,
115-
Function&& function) {
116-
IndexAndValueTypeVisitor visitor;
117-
return VisitTypeInline(index_type, &visitor, value_type,
118-
std::forward<Function>(function));
119-
}
120-
12180
} // namespace internal
12281
} // namespace arrow

cpp/src/arrow/tensor/coo_converter.cc

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "arrow/tensor/converter_internal.h"
1919

2020
#include <algorithm>
21-
#include <cmath>
2221
#include <cstdint>
2322
#include <memory>
2423
#include <numeric>
@@ -28,11 +27,10 @@
2827
#include "arrow/status.h"
2928
#include "arrow/tensor.h"
3029
#include "arrow/type.h"
31-
#include "arrow/type_traits.h"
3230
#include "arrow/util/checked_cast.h"
3331
#include "arrow/util/logging_internal.h"
3432
#include "arrow/util/macros.h"
35-
#include "arrow/visit_type_inline.h"
33+
#include "arrow/util/sparse_tensor_util.h"
3634

3735
namespace arrow {
3836

@@ -42,57 +40,6 @@ namespace internal {
4240

4341
namespace {
4442

45-
template <typename ValueType, typename IndexType>
46-
Status ValidateSparseCooTensorCreation(const SparseCOOIndex& sparse_coo_index,
47-
const Buffer& sparse_coo_values_buffer,
48-
const Tensor& tensor) {
49-
using IndexCType = typename IndexType::c_type;
50-
using ValueCType = typename ValueType::c_type;
51-
52-
const auto& indices = sparse_coo_index.indices();
53-
const auto* indices_data = sparse_coo_index.indices()->data()->data_as<IndexCType>();
54-
const auto* sparse_coo_values = sparse_coo_values_buffer.data_as<ValueCType>();
55-
56-
ARROW_ASSIGN_OR_RAISE(auto non_zero_count, tensor.CountNonZero());
57-
58-
if (indices->shape()[0] != non_zero_count) {
59-
return Status::Invalid("Mismatch between non-zero count in sparse tensor (",
60-
indices->shape()[0], ") and dense tensor (", non_zero_count,
61-
")");
62-
} else if (indices->shape()[1] != static_cast<int64_t>(tensor.shape().size())) {
63-
return Status::Invalid("Mismatch between coordinate dimension in sparse tensor (",
64-
indices->shape()[1], ") and tensor shape (",
65-
tensor.shape().size(), ")");
66-
}
67-
68-
auto coord_size = indices->shape()[1];
69-
std::vector<int64_t> coord(coord_size);
70-
for (int64_t i = 0; i < indices->shape()[0]; i++) {
71-
if (!is_not_zero<ValueType>(sparse_coo_values[i])) {
72-
return Status::Invalid("Sparse tensor values must be non-zero");
73-
}
74-
75-
for (int64_t j = 0; j < coord_size; j++) {
76-
coord[j] = static_cast<int64_t>(indices_data[i * coord_size + j]);
77-
}
78-
79-
if (sparse_coo_values[i] != tensor.Value<ValueType>(coord)) {
80-
if constexpr (is_floating_type<ValueType>::value) {
81-
if (!std::isnan(tensor.Value<ValueType>(coord)) ||
82-
!std::isnan(sparse_coo_values[i])) {
83-
return Status::Invalid(
84-
"Inconsistent values between sparse tensor and dense tensor");
85-
}
86-
} else {
87-
return Status::Invalid(
88-
"Inconsistent values between sparse tensor and dense tensor");
89-
}
90-
}
91-
}
92-
93-
return Status::OK();
94-
}
95-
9643
template <typename IndexCType>
9744
inline void IncrementRowMajorIndex(std::vector<IndexCType>& coord,
9845
const std::vector<int64_t>& shape) {
@@ -265,8 +212,6 @@ class SparseCOOTensorConverter {
265212
indices_shape, indices_strides);
266213
ARROW_ASSIGN_OR_RAISE(sparse_index, SparseCOOIndex::Make(coords, true));
267214
data = std::move(values_buffer);
268-
DCHECK_OK((ValidateSparseCooTensorCreation<ValueType, IndexType>(*sparse_index, *data,
269-
tensor_)));
270215
return Status::OK();
271216
}
272217

@@ -328,7 +273,8 @@ Status MakeSparseCOOTensorFromTensor(const Tensor& tensor,
328273
std::shared_ptr<Buffer>* out_data) {
329274
SparseCOOTensorConverter converter(tensor, index_value_type, pool);
330275
ConverterVisitor visitor{converter};
331-
ARROW_RETURN_NOT_OK(VisitValueAndIndexType(*tensor.type(), *index_value_type, visitor));
276+
ARROW_RETURN_NOT_OK(
277+
util::VisitCOOTensorType(*tensor.type(), *index_value_type, visitor));
332278
*out_sparse_index = checked_pointer_cast<SparseIndex>(converter.sparse_index);
333279
*out_data = converter.data;
334280
return Status::OK();

cpp/src/arrow/tensor/csf_converter.cc

Lines changed: 7 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
#include "arrow/tensor/converter_internal.h"
1919

2020
#include <algorithm>
21-
#include <cmath>
2221
#include <cstdint>
2322
#include <limits>
2423
#include <memory>
@@ -30,11 +29,10 @@
3029
#include "arrow/status.h"
3130
#include "arrow/tensor.h"
3231
#include "arrow/type.h"
33-
#include "arrow/type_traits.h"
3432
#include "arrow/util/checked_cast.h"
3533
#include "arrow/util/logging_internal.h"
3634
#include "arrow/util/sort_internal.h"
37-
#include "arrow/visit_type_inline.h"
35+
#include "arrow/util/sparse_tensor_util.h"
3836

3937
namespace arrow {
4038

@@ -58,89 +56,6 @@ inline void IncrementIndex(std::vector<int64_t>& coord, const std::vector<int64_
5856
}
5957
}
6058

61-
template <typename ValueType, typename IndexType>
62-
Status CheckValues(const SparseCSFIndex& sparse_csf_index,
63-
const typename ValueType::c_type* values, const Tensor& tensor,
64-
const int64_t dim, const int64_t dim_offset, const int64_t start,
65-
const int64_t stop) {
66-
using ValueCType = typename ValueType::c_type;
67-
using IndexCType = typename IndexType::c_type;
68-
69-
const auto& indices = sparse_csf_index.indices();
70-
const auto& indptr = sparse_csf_index.indptr();
71-
const auto& axis_order = sparse_csf_index.axis_order();
72-
auto ndim = indices.size();
73-
auto strides = tensor.strides();
74-
75-
const auto& cur_indices = indices[dim];
76-
const auto* indices_data = cur_indices->data()->data_as<IndexCType>() + start;
77-
78-
if (dim == static_cast<int64_t>(ndim) - 1) {
79-
for (auto i = start; i < stop; ++i) {
80-
auto index = static_cast<int64_t>(*indices_data);
81-
const int64_t offset = dim_offset + index * strides[axis_order[dim]];
82-
83-
auto sparse_value = values[i];
84-
auto tensor_value =
85-
*reinterpret_cast<const ValueCType*>(tensor.raw_data() + offset);
86-
if (!is_not_zero<ValueType>(sparse_value)) {
87-
return Status::Invalid("Sparse tensor values must be non-zero");
88-
} else if (sparse_value != tensor_value) {
89-
if constexpr (is_floating_type<ValueType>::value) {
90-
if (!std::isnan(tensor_value) || !std::isnan(sparse_value)) {
91-
return Status::Invalid(
92-
"Inconsistent values between sparse tensor and dense tensor");
93-
}
94-
} else {
95-
return Status::Invalid(
96-
"Inconsistent values between sparse tensor and dense tensor");
97-
}
98-
}
99-
++indices_data;
100-
}
101-
} else {
102-
const auto& cur_indptr = indptr[dim];
103-
const auto* indptr_data = cur_indptr->data()->data_as<IndexCType>() + start;
104-
105-
for (int64_t i = start; i < stop; ++i) {
106-
const int64_t index = *indices_data;
107-
int64_t offset = dim_offset + index * strides[axis_order[dim]];
108-
auto next_start = static_cast<int64_t>(*indptr_data);
109-
auto next_stop = static_cast<int64_t>(*(indptr_data + 1));
110-
111-
ARROW_RETURN_NOT_OK((CheckValues<ValueType, IndexType>(
112-
sparse_csf_index, values, tensor, dim + 1, offset, next_start, next_stop)));
113-
114-
++indices_data;
115-
++indptr_data;
116-
}
117-
}
118-
return Status::OK();
119-
}
120-
121-
template <typename ValueType, typename IndexType>
122-
Status ValidateSparseTensorCSFCreation(const SparseIndex& sparse_index,
123-
const Buffer& values_buffer,
124-
const Tensor& tensor) {
125-
auto sparse_csf_index = checked_cast<const SparseCSFIndex&>(sparse_index);
126-
const auto* values = values_buffer.data_as<typename ValueType::c_type>();
127-
const auto& indices = sparse_csf_index.indices();
128-
129-
ARROW_ASSIGN_OR_RAISE(auto non_zero_count, tensor.CountNonZero());
130-
if (indices.back()->size() != non_zero_count) {
131-
return Status::Invalid("Mismatch between non-zero count in sparse tensor (",
132-
indices.back()->size(), ") and dense tensor (", non_zero_count,
133-
")");
134-
} else if (indices.size() != tensor.shape().size()) {
135-
return Status::Invalid("Mismatch between coordinate dimension in sparse tensor (",
136-
indices.size(), ") and tensor shape (", tensor.shape().size(),
137-
")");
138-
} else {
139-
return CheckValues<ValueType, IndexType>(sparse_csf_index, values, tensor, 0, 0, 0,
140-
sparse_csf_index.indptr()[0]->size() - 1);
141-
}
142-
}
143-
14459
// ----------------------------------------------------------------------
14560
// SparseTensorConverter for SparseCSFIndex
14661

@@ -151,8 +66,10 @@ class SparseCSFTensorConverter {
15166
MemoryPool* pool)
15267
: tensor_(tensor), index_value_type_(index_value_type), pool_(pool) {}
15368

154-
template <typename ValueType, typename IndexType>
155-
Status Convert(const ValueType&, const IndexType&) {
69+
// Note: The same type is used for both indices and indptr during
70+
// tensor-to-CSF-tensor conversion.
71+
template <typename ValueType, typename IndexType, typename IndexPointerType>
72+
Status Convert(const ValueType&, const IndexType&, const IndexPointerType&) {
15673
using ValueCType = typename ValueType::c_type;
15774
using IndexCType = typename IndexType::c_type;
15875
RETURN_NOT_OK(::arrow::internal::CheckSparseIndexMaximumValue(index_value_type_,
@@ -235,8 +152,6 @@ class SparseCSFTensorConverter {
235152
ARROW_ASSIGN_OR_RAISE(
236153
sparse_index, SparseCSFIndex::Make(index_value_type_, indices_shapes, axis_order,
237154
indptr_buffers, indices_buffers));
238-
DCHECK_OK((ValidateSparseTensorCSFCreation<ValueType, IndexType>(*sparse_index, *data,
239-
tensor_)));
240155
return Status::OK();
241156
}
242157

@@ -353,7 +268,8 @@ Status MakeSparseCSFTensorFromTensor(const Tensor& tensor,
353268
std::shared_ptr<Buffer>* out_data) {
354269
SparseCSFTensorConverter converter(tensor, index_value_type, pool);
355270
ConverterVisitor visitor{converter};
356-
ARROW_RETURN_NOT_OK(VisitValueAndIndexType(*tensor.type(), *index_value_type, visitor));
271+
ARROW_RETURN_NOT_OK(
272+
util::VisitCSXType(*tensor.type(), *index_value_type, *index_value_type, visitor));
357273
*out_sparse_index = checked_pointer_cast<SparseIndex>(converter.sparse_index);
358274
*out_data = converter.data;
359275
return Status::OK();

0 commit comments

Comments
 (0)