Skip to content

Commit bd94e80

Browse files
author
Giorgi Lomia
committed
Did more work on arrowArray visitor WIP.
1 parent f04a126 commit bd94e80

File tree

2 files changed

+88
-98
lines changed

2 files changed

+88
-98
lines changed

libsupport/include/katana/ArrowVisitor.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,14 @@ using AcceptAllArrowTypes = std::tuple<
307307
arrow::LargeStringType, arrow::StructType, arrow::ListType,
308308
arrow::LargeListType, arrow::NullType>;
309309

310+
using AcceptAllFlatTypes = std::tuple<
311+
arrow::Int8Type, arrow::UInt8Type, arrow::Int16Type, arrow::UInt16Type,
312+
arrow::Int32Type, arrow::UInt32Type, arrow::Int64Type, arrow::UInt64Type,
313+
arrow::FloatType, arrow::DoubleType, arrow::FloatType, arrow::DoubleType,
314+
arrow::BooleanType, arrow::Date32Type, arrow::Date64Type, arrow::Time32Type,
315+
arrow::Time64Type, arrow::TimestampType, arrow::StringType,
316+
arrow::LargeStringType, arrow::NullType>;
317+
310318
template <typename... Args>
311319
using tuple_cat_t = decltype(std::tuple_cat(std::declval<Args>()...));
312320

tools/graph-stats/graph-memory-stats.cpp

Lines changed: 80 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include <string>
2323
#include <unordered_map>
2424

25-
#include <arrow/array/builder_base.h>
2625
#include <arrow/type_traits.h>
2726

2827
#include "katana/ArrowVisitor.h"
@@ -43,137 +42,122 @@ using map_string_element = std::unordered_map<std::string, std::string>;
4342
using memory_map = std::unordered_map<
4443
std::string, std::variant<map_element, map_string_element>>;
4544

45+
inline std::shared_ptr<arrow::DataType>
46+
GetArrowType(const arrow::Scalar& scalar) {
47+
return scalar.type;
48+
}
49+
50+
inline std::shared_ptr<arrow::DataType>
51+
GetArrowType(const arrow::Array& array) {
52+
return array.type();
53+
}
54+
55+
inline std::shared_ptr<arrow::DataType>
56+
GetArrowType(const arrow::ArrayBuilder* builder) {
57+
return builder->type();
58+
}
59+
4660
struct Visitor : public katana::ArrowVisitor {
4761
using ResultType = katana::Result<int64_t>;
48-
using AcceptTypes = std::tuple<
49-
arrow::Int8Type, arrow::UInt8Type, arrow::Int16Type, arrow::UInt16Type,
50-
arrow::Int32Type, arrow::UInt32Type, arrow::Int64Type, arrow::UInt64Type,
51-
arrow::FloatType, arrow::DoubleType, arrow::FloatType, arrow::DoubleType,
52-
arrow::BooleanType, arrow::Date32Type, arrow::Date64Type,
53-
arrow::Time32Type, arrow::Time64Type, arrow::TimestampType,
54-
arrow::StringType, arrow::LargeStringType, arrow::StructType,
55-
arrow::NullType>;
56-
57-
template <typename ArrowType, typename ScalarType>
62+
using AcceptTypes = std::tuple<katana::AcceptAllFlatTypes>;
63+
64+
template <typename ArrowType, typename ArrayType>
65+
arrow::enable_if_null<ArrowType, ResultType> Call(const ArrayType& scalars) {
66+
std::cout << scalars.total_values_length() << "\n";
67+
return 0;
68+
}
69+
70+
template <typename ArrowType, typename ArrayType>
5871
std::enable_if_t<
5972
arrow::is_number_type<ArrowType>::value ||
6073
arrow::is_boolean_type<ArrowType>::value ||
6174
arrow::is_temporal_type<ArrowType>::value,
6275
ResultType>
63-
Call(const ScalarType& scalar) {
64-
return scalar.value;
76+
Call(const ArrayType& scalars) {
77+
// ResultType width = 0;
78+
std::cout << scalars.total_values_length() << "\n";
79+
return 0;
6580
}
6681

67-
template <typename ArrowType, typename ScalarType>
82+
template <typename ArrowType, typename ArrayType>
6883
arrow::enable_if_string_like<ArrowType, ResultType> Call(
69-
const ScalarType& scalar) {
70-
const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
71-
auto res = (arrow::util::string_view)(*typed_scalar->value);
72-
// TODO (giorgi): make this KATANA_CHECKED
73-
// if (!res.ok()) {
74-
// return KATANA_ERROR(
75-
// katana::ErrorCode::ArrowError, "arrow builder failed append: {}",
76-
// res);
77-
// }
78-
return res;
84+
const ArrayType& scalars) {
85+
std::cout << scalars.total_values_length() << "\n";
86+
87+
return 0;
7988
}
8089

81-
ResultType AcceptFailed(const arrow::Scalar& scalar) {
90+
template <typename Param>
91+
ResultType AcceptFailed(Param&& param) {
8292
return KATANA_ERROR(
83-
katana::ErrorCode::ArrowError, "no matching type {}",
84-
scalar.type->name());
93+
"Instant functions do not accept {}", GetArrowType(param)->ToString());
8594
}
8695
};
8796

88-
// struct ToArrayVisitor : public katana::ArrowVisitor {
89-
// // Internal data and constructor
90-
// const std::shared_ptr<arrow::Array> scalars;
91-
// ToArrayVisitor(const std::shared_ptr<arrow::Array> input) : scalars(input) {}
92-
93-
// using ResultType = katana::Result<std::shared_ptr<arrow::Array>>;
94-
97+
// struct Visitor : public katana::ArrowVisitor {
98+
// const std::shared_ptr<arrow::Scalar>& scalar;
99+
// Visitor(const std::shared_ptr<arrow::Scalar>& input) : scalar(input) {}
100+
// using ResultType = katana::Result<int64_t>;
95101
// using AcceptTypes = std::tuple<katana::AcceptAllArrowTypes>;
96102

97-
// template <typename ArrowType, typename BuilderType>
98-
// arrow::enable_if_null<ArrowType, ResultType> Call(BuilderType* builder) {
99-
// return KATANA_CHECKED(builder->Finish());
103+
// template <typename ArrowType, typename WidthType>
104+
// arrow::enable_if_null<ArrowType, ResultType> Call(
105+
// const WidthType& width_tracker) {
106+
// width_tracker = 0;
107+
// return width_tracker;
100108
// }
101109

102-
// template <typename ArrowType, typename BuilderType>
110+
// template <typename ArrowType, typename WidthType>
103111
// std::enable_if_t<
104112
// arrow::is_number_type<ArrowType>::value ||
105113
// arrow::is_boolean_type<ArrowType>::value ||
106114
// arrow::is_temporal_type<ArrowType>::value,
107115
// ResultType>
108-
// Call(BuilderType* builder) {
116+
// Call(const WidthType& width_tracker) {
109117
// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
110-
111-
// KATANA_CHECKED(builder->Reserve(scalars->length()));
112-
// for (auto j = 0; j < scalars->length(); j++) {
113-
// auto scalar = *scalars->GetScalar(j);
114-
// if (scalar != nullptr && scalar->is_valid) {
115-
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
116-
// builder->UnsafeAppend(typed_scalar->value);
117-
// } else {
118-
// builder->UnsafeAppendNull();
119-
// }
118+
// if (scalar != nullptr && scalar->is_valid) {
119+
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
120+
// return typed_scalar->value;
121+
// } else {
122+
// return KATANA_ERROR(
123+
// katana::ErrorCode::ArrowError, "arrow visitor failed to read: NULL");
120124
// }
121-
// return KATANA_CHECKED(builder->Finish());
122125
// }
123126

124-
// template <typename ArrowType, typename BuilderType>
127+
// template <typename ArrowType, typename WidthType>
125128
// arrow::enable_if_string_like<ArrowType, ResultType> Call(
126-
// BuilderType* builder) {
129+
// const WidthType& width_tracker) {
127130
// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
128-
// // same as above, but with string_view and Append instead of UnsafeAppend
129-
// for (auto j = 0; j < scalars->length(); j++) {
130-
// auto scalar = *scalars->GetScalar(j);
131-
// if (scalar != nullptr && scalar->is_valid) {
132-
// // ->value->ToString() works, scalar->ToString() yields "..."
133-
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
134-
// if (auto res = builder->Append(
135-
// (arrow::util::string_view)(*typed_scalar->value));
136-
// !res.ok()) {
137-
// return KATANA_ERROR(
138-
// katana::ErrorCode::ArrowError, "arrow builder failed append: {}",
139-
// res);
140-
// }
141-
// } else {
142-
// if (auto res = builder->AppendNull(); !res.ok()) {
143-
// return KATANA_ERROR(
144-
// katana::ErrorCode::ArrowError,
145-
// "arrow builder failed append null: {}", res);
146-
// }
147-
// }
131+
// if (scalar != nullptr && scalar->is_valid) {
132+
// // ->value->ToString() works, scalar->ToString() yields "..."
133+
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
134+
// auto res = (arrow::util::string_view)(*typed_scalar->value);
135+
// return res;
136+
// } else {
137+
// return KATANA_ERROR(
138+
// katana::ErrorCode::ArrowError, "arrow visitor failed to read: NULL");
148139
// }
149-
// return KATANA_CHECKED(builder->Finish());
150140
// }
151141

152-
// template <typename ArrowType, typename BuilderType>
142+
// template <typename ArrowType, typename WidthType>
153143
// std::enable_if_t<
154144
// arrow::is_list_type<ArrowType>::value ||
155145
// arrow::is_struct_type<ArrowType>::value,
156146
// ResultType>
157-
// Call(BuilderType* builder) {
147+
// Call(const WidthType& width_tracker) {
158148
// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
159149
// // use a visitor to traverse more complex types
160-
// katana::AppendScalarToBuilder visitor(builder);
161-
// for (auto j = 0; j < scalars->length(); j++) {
162-
// auto scalar = *scalars->GetScalar(j);
163-
// if (scalar != nullptr && scalar->is_valid) {
164-
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
165-
// KATANA_CHECKED(visitor.Call<ArrowType>(*typed_scalar));
166-
// } else {
167-
// KATANA_CHECKED(builder->AppendNull());
168-
// }
150+
// Visitor visitor(scalar);
151+
// if (scalar != nullptr && scalar->is_valid) {
152+
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
153+
// KATANA_CHECKED(visitor.Call<ArrowType>(*typed_scalar));
169154
// }
170-
// return KATANA_CHECKED(builder->Finish());
171155
// }
172156

173-
// ResultType AcceptFailed(const arrow::ArrayBuilder* builder) {
157+
// ResultType AcceptFailed(const arrow::Scalar& scalar) {
174158
// return KATANA_ERROR(
175159
// katana::ErrorCode::ArrowError, "no matching type {}",
176-
// builder->type()->name());
160+
// scalar.type->name());
177161
// }
178162
// };
179163

@@ -202,19 +186,17 @@ PrintStringMapping(const std::unordered_map<std::string, std::string>& u) {
202186
std::cout << "\n";
203187
}
204188

205-
katana::Result<std::shared_ptr<arrow::Array>>
189+
void
206190
RunVisit(const std::shared_ptr<arrow::Array> scalars) {
207-
Visitor v;
208191
int64_t total = 0;
209-
for (auto j = 0; j < scalars->length(); j++) {
210-
auto s = *scalars->GetScalar(j);
211-
auto res = katana::VisitArrow(v, *s);
212-
KATANA_LOG_VASSERT(res, "unexpected errror {}", res.error());
213-
total += res.value();
214-
}
192+
Visitor v;
193+
arrow::Array* arr = scalars.get();
194+
auto res = katana::VisitArrow(v, *arr);
195+
KATANA_LOG_VASSERT(res, "unexpected errror {}", res.error());
196+
total += res.value();
215197

216-
KATANA_LOG_VASSERT(
217-
total == scalars->length(), "{} != {}", total, scalars->length());
198+
// KATANA_LOG_VASSERT(
199+
// total == scalars->length(), "{} != {}", total, scalars->length());
218200
}
219201

220202
void
@@ -258,7 +240,7 @@ GatherMemoryAllocation(
258240
alloc_size = 0;
259241
prop_size = 0;
260242
auto bit_width = arrow::bit_width(dtype->id());
261-
auto visited_arr = RunVisit(prop_field);
243+
RunVisit(prop_field);
262244
for (auto j = 0; j < prop_field->length(); j++) {
263245
if (prop_field->IsValid(j)) {
264246
auto scal_ptr = *prop_field->GetScalar(j);

0 commit comments

Comments
 (0)