2222#include < string>
2323#include < unordered_map>
2424
25- #include < arrow/array/builder_base.h>
2625#include < arrow/type_traits.h>
2726
2827#include " katana/ArrowVisitor.h"
@@ -43,137 +42,122 @@ using map_string_element = std::unordered_map<std::string, std::string>;
4342using memory_map = std::unordered_map<
4443 std::string, std::variant<map_element, map_string_element>>;
4544
45+ inline std::shared_ptr<arrow::DataType>
46+ GetArrowType (const arrow::Scalar& scalar) {
47+ return scalar.type ;
48+ }
49+
50+ inline std::shared_ptr<arrow::DataType>
51+ GetArrowType (const arrow::Array& array) {
52+ return array.type ();
53+ }
54+
55+ inline std::shared_ptr<arrow::DataType>
56+ GetArrowType (const arrow::ArrayBuilder* builder) {
57+ return builder->type ();
58+ }
59+
4660struct Visitor : public katana ::ArrowVisitor {
4761 using ResultType = katana::Result<int64_t >;
48- using AcceptTypes = std::tuple<
49- arrow::Int8Type, arrow::UInt8Type, arrow::Int16Type, arrow::UInt16Type,
50- arrow::Int32Type, arrow::UInt32Type, arrow::Int64Type, arrow::UInt64Type,
51- arrow::FloatType, arrow::DoubleType, arrow::FloatType, arrow::DoubleType,
52- arrow::BooleanType, arrow::Date32Type, arrow::Date64Type,
53- arrow::Time32Type, arrow::Time64Type, arrow::TimestampType,
54- arrow::StringType, arrow::LargeStringType, arrow::StructType,
55- arrow::NullType>;
56-
57- template <typename ArrowType, typename ScalarType>
62+ using AcceptTypes = std::tuple<katana::AcceptAllFlatTypes>;
63+
64+ template <typename ArrowType, typename ArrayType>
65+ arrow::enable_if_null<ArrowType, ResultType> Call (const ArrayType& scalars) {
66+ std::cout << scalars.total_values_length () << " \n " ;
67+ return 0 ;
68+ }
69+
70+ template <typename ArrowType, typename ArrayType>
5871 std::enable_if_t <
5972 arrow::is_number_type<ArrowType>::value ||
6073 arrow::is_boolean_type<ArrowType>::value ||
6174 arrow::is_temporal_type<ArrowType>::value,
6275 ResultType>
63- Call (const ScalarType& scalar) {
64- return scalar.value ;
76+ Call (const ArrayType& scalars) {
77+ // ResultType width = 0;
78+ std::cout << scalars.total_values_length () << " \n " ;
79+ return 0 ;
6580 }
6681
67- template <typename ArrowType, typename ScalarType >
82+ template <typename ArrowType, typename ArrayType >
6883 arrow::enable_if_string_like<ArrowType, ResultType> Call (
69- const ScalarType& scalar) {
70- const ScalarType* typed_scalar = static_cast <ScalarType*>(scalar.get ());
71- auto res = (arrow::util::string_view)(*typed_scalar->value );
72- // TODO (giorgi): make this KATANA_CHECKED
73- // if (!res.ok()) {
74- // return KATANA_ERROR(
75- // katana::ErrorCode::ArrowError, "arrow builder failed append: {}",
76- // res);
77- // }
78- return res;
84+ const ArrayType& scalars) {
85+ std::cout << scalars.total_values_length () << " \n " ;
86+
87+ return 0 ;
7988 }
8089
81- ResultType AcceptFailed (const arrow::Scalar& scalar) {
90+ template <typename Param>
91+ ResultType AcceptFailed (Param&& param) {
8292 return KATANA_ERROR (
83- katana::ErrorCode::ArrowError, " no matching type {}" ,
84- scalar.type ->name ());
93+ " Instant functions do not accept {}" , GetArrowType (param)->ToString ());
8594 }
8695};
8796
88- // struct ToArrayVisitor : public katana::ArrowVisitor {
89- // // Internal data and constructor
90- // const std::shared_ptr<arrow::Array> scalars;
91- // ToArrayVisitor(const std::shared_ptr<arrow::Array> input) : scalars(input) {}
92-
93- // using ResultType = katana::Result<std::shared_ptr<arrow::Array>>;
94-
97+ // struct Visitor : public katana::ArrowVisitor {
98+ // const std::shared_ptr<arrow::Scalar>& scalar;
99+ // Visitor(const std::shared_ptr<arrow::Scalar>& input) : scalar(input) {}
100+ // using ResultType = katana::Result<int64_t>;
95101// using AcceptTypes = std::tuple<katana::AcceptAllArrowTypes>;
96102
97- // template <typename ArrowType, typename BuilderType>
98- // arrow::enable_if_null<ArrowType, ResultType> Call(BuilderType* builder) {
99- // return KATANA_CHECKED(builder->Finish());
103+ // template <typename ArrowType, typename WidthType>
104+ // arrow::enable_if_null<ArrowType, ResultType> Call(
105+ // const WidthType& width_tracker) {
106+ // width_tracker = 0;
107+ // return width_tracker;
100108// }
101109
102- // template <typename ArrowType, typename BuilderType >
110+ // template <typename ArrowType, typename WidthType >
103111// std::enable_if_t<
104112// arrow::is_number_type<ArrowType>::value ||
105113// arrow::is_boolean_type<ArrowType>::value ||
106114// arrow::is_temporal_type<ArrowType>::value,
107115// ResultType>
108- // Call(BuilderType* builder ) {
116+ // Call(const WidthType& width_tracker ) {
109117// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
110-
111- // KATANA_CHECKED(builder->Reserve(scalars->length()));
112- // for (auto j = 0; j < scalars->length(); j++) {
113- // auto scalar = *scalars->GetScalar(j);
114- // if (scalar != nullptr && scalar->is_valid) {
115- // const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
116- // builder->UnsafeAppend(typed_scalar->value);
117- // } else {
118- // builder->UnsafeAppendNull();
119- // }
118+ // if (scalar != nullptr && scalar->is_valid) {
119+ // const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
120+ // return typed_scalar->value;
121+ // } else {
122+ // return KATANA_ERROR(
123+ // katana::ErrorCode::ArrowError, "arrow visitor failed to read: NULL");
120124// }
121- // return KATANA_CHECKED(builder->Finish());
122125// }
123126
124- // template <typename ArrowType, typename BuilderType >
127+ // template <typename ArrowType, typename WidthType >
125128// arrow::enable_if_string_like<ArrowType, ResultType> Call(
126- // BuilderType* builder ) {
129+ // const WidthType& width_tracker ) {
127130// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
128- // // same as above, but with string_view and Append instead of UnsafeAppend
129- // for (auto j = 0; j < scalars->length(); j++) {
130- // auto scalar = *scalars->GetScalar(j);
131- // if (scalar != nullptr && scalar->is_valid) {
132- // // ->value->ToString() works, scalar->ToString() yields "..."
133- // const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
134- // if (auto res = builder->Append(
135- // (arrow::util::string_view)(*typed_scalar->value));
136- // !res.ok()) {
137- // return KATANA_ERROR(
138- // katana::ErrorCode::ArrowError, "arrow builder failed append: {}",
139- // res);
140- // }
141- // } else {
142- // if (auto res = builder->AppendNull(); !res.ok()) {
143- // return KATANA_ERROR(
144- // katana::ErrorCode::ArrowError,
145- // "arrow builder failed append null: {}", res);
146- // }
147- // }
131+ // if (scalar != nullptr && scalar->is_valid) {
132+ // // ->value->ToString() works, scalar->ToString() yields "..."
133+ // const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
134+ // auto res = (arrow::util::string_view)(*typed_scalar->value);
135+ // return res;
136+ // } else {
137+ // return KATANA_ERROR(
138+ // katana::ErrorCode::ArrowError, "arrow visitor failed to read: NULL");
148139// }
149- // return KATANA_CHECKED(builder->Finish());
150140// }
151141
152- // template <typename ArrowType, typename BuilderType >
142+ // template <typename ArrowType, typename WidthType >
153143// std::enable_if_t<
154144// arrow::is_list_type<ArrowType>::value ||
155145// arrow::is_struct_type<ArrowType>::value,
156146// ResultType>
157- // Call(BuilderType* builder ) {
147+ // Call(const WidthType& width_tracker ) {
158148// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
159149// // use a visitor to traverse more complex types
160- // katana::AppendScalarToBuilder visitor(builder);
161- // for (auto j = 0; j < scalars->length(); j++) {
162- // auto scalar = *scalars->GetScalar(j);
163- // if (scalar != nullptr && scalar->is_valid) {
164- // const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
165- // KATANA_CHECKED(visitor.Call<ArrowType>(*typed_scalar));
166- // } else {
167- // KATANA_CHECKED(builder->AppendNull());
168- // }
150+ // Visitor visitor(scalar);
151+ // if (scalar != nullptr && scalar->is_valid) {
152+ // const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
153+ // KATANA_CHECKED(visitor.Call<ArrowType>(*typed_scalar));
169154// }
170- // return KATANA_CHECKED(builder->Finish());
171155// }
172156
173- // ResultType AcceptFailed(const arrow::ArrayBuilder* builder ) {
157+ // ResultType AcceptFailed(const arrow::Scalar& scalar ) {
174158// return KATANA_ERROR(
175159// katana::ErrorCode::ArrowError, "no matching type {}",
176- // builder-> type() ->name());
160+ // scalar. type->name());
177161// }
178162// };
179163
@@ -202,19 +186,17 @@ PrintStringMapping(const std::unordered_map<std::string, std::string>& u) {
202186 std::cout << " \n " ;
203187}
204188
205- katana::Result<std::shared_ptr<arrow::Array>>
189+ void
206190RunVisit (const std::shared_ptr<arrow::Array> scalars) {
207- Visitor v;
208191 int64_t total = 0 ;
209- for (auto j = 0 ; j < scalars->length (); j++) {
210- auto s = *scalars->GetScalar (j);
211- auto res = katana::VisitArrow (v, *s);
212- KATANA_LOG_VASSERT (res, " unexpected errror {}" , res.error ());
213- total += res.value ();
214- }
192+ Visitor v;
193+ arrow::Array* arr = scalars.get ();
194+ auto res = katana::VisitArrow (v, *arr);
195+ KATANA_LOG_VASSERT (res, " unexpected errror {}" , res.error ());
196+ total += res.value ();
215197
216- KATANA_LOG_VASSERT (
217- total == scalars->length (), " {} != {}" , total, scalars->length ());
198+ // KATANA_LOG_VASSERT(
199+ // total == scalars->length(), "{} != {}", total, scalars->length());
218200}
219201
220202void
@@ -258,7 +240,7 @@ GatherMemoryAllocation(
258240 alloc_size = 0 ;
259241 prop_size = 0 ;
260242 auto bit_width = arrow::bit_width (dtype->id ());
261- auto visited_arr = RunVisit (prop_field);
243+ RunVisit (prop_field);
262244 for (auto j = 0 ; j < prop_field->length (); j++) {
263245 if (prop_field->IsValid (j)) {
264246 auto scal_ptr = *prop_field->GetScalar (j);
0 commit comments