Skip to content

Commit 2fea185

Browse files
JohanMabillepre-commit-ci[bot]Alex-PLACET
authored
Added record bact import from/export to struct_array (#336)
Added record bacth import from/export to struct_array --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alexis Placet <[email protected]>
1 parent bab08c1 commit 2fea185

File tree

6 files changed

+148
-10
lines changed

6 files changed

+148
-10
lines changed

include/sparrow/array_api.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,18 @@ namespace sparrow
143143
*/
144144
SPARROW_API enum data_type data_type() const;
145145

146+
/**
147+
* @returns the name of the \ref array or an empty
148+
* string if the array does not have a name.
149+
*/
150+
SPARROW_API std::optional<std::string_view> name() const;
151+
152+
/**
153+
* Sets the name of the array to \ref name.
154+
* @param name The new name of the array.
155+
*/
156+
SPARROW_API void set_name(std::optional<std::string_view> name);
157+
146158
/**
147159
* Checks if the array has no element, i.e. whether size() == 0.
148160
*/

include/sparrow/buffer/dynamic_bitset/dynamic_bitset.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ namespace sparrow
177177
&& std::unsigned_integral<std::ranges::range_value_t<T>>) )
178178
&& (!std::same_as<std::remove_cvref_t<T>, std::string>
179179
&& !std::same_as<std::remove_cvref_t<T>, std::string_view>
180-
&& !std::same_as<T, const char*>);
180+
&& !std::same_as<std::decay_t<T>, const char*>);
181181

182182
template <validity_bitmap_input R>
183183
validity_bitmap ensure_validity_bitmap(std::size_t size, R&& validity_input)
184184
{
185185
return detail::ensure_validity_bitmap_impl(size, std::forward<R>(validity_input));
186186
}
187187

188-
} // namespace sparrow
188+
} // namespace sparrow

include/sparrow/record_batch.hpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <vector>
2323

2424
#include "sparrow/array.hpp"
25+
#include "sparrow/layout/struct_layout/struct_array.hpp"
2526
#include "sparrow/utils/contracts.hpp"
2627

2728
#if defined(__cpp_lib_format)
@@ -65,13 +66,32 @@ namespace sparrow
6566
requires(std::convertible_to<std::ranges::range_value_t<NR>, std::string> and std::same_as<std::ranges::range_value_t<CR>, array>)
6667
record_batch(NR&& names, CR&& columns);
6768

69+
/*
70+
* Constructs a @ref record_batch from a range of arrays. Each array
71+
* must have a name: if \c arr is an array, then \c arr.name(), must
72+
* not return an empty string.
73+
*
74+
* @param comumns An input range of arrays
75+
*/
76+
template <std::ranges::input_range CR>
77+
requires std::same_as<std::ranges::range_value_t<CR>, array>
78+
record_batch(CR&& columns);
79+
6880
/**
6981
* Constructs a record_batch from a list of \c std::pair<name_type, array>.
7082
*
7183
* @param init a list of pair "name - array".
7284
*/
7385
SPARROW_API record_batch(initializer_type init);
7486

87+
/**
88+
* Construct a record batch from the given struct array.
89+
* The array must owns its internal arrow structures.
90+
*
91+
* @param ar An input struct array
92+
*/
93+
SPARROW_API record_batch(struct_array&& ar);
94+
7595
SPARROW_API record_batch(const record_batch&);
7696
SPARROW_API record_batch& operator=(const record_batch&);
7797

@@ -129,6 +149,13 @@ namespace sparrow
129149
*/
130150
SPARROW_API column_range columns() const;
131151

152+
/**
153+
* Moves the internal columns of the record batch into a struct_array
154+
* object. The record batch is empty anymore after calling this
155+
* method.
156+
*/
157+
SPARROW_API struct_array extract_struct_array();
158+
132159
private:
133160

134161
template <class U, class R>
@@ -169,6 +196,31 @@ namespace sparrow
169196
SPARROW_ASSERT_TRUE(check_consistency());
170197
}
171198

199+
namespace detail
200+
{
201+
std::vector<record_batch::name_type> get_names(const std::vector<array>& array_list)
202+
{
203+
const auto names = array_list
204+
| std::views::transform(
205+
[](const array& ar)
206+
{
207+
return ar.name().value();
208+
}
209+
);
210+
return {names.begin(), names.end()};
211+
}
212+
}
213+
214+
template <std::ranges::input_range CR>
215+
requires std::same_as<std::ranges::range_value_t<CR>, array>
216+
record_batch::record_batch(CR&& columns)
217+
: m_name_list(detail::get_names(columns))
218+
, m_array_list(to_vector<array>(std::move(columns)))
219+
{
220+
init_array_map();
221+
SPARROW_ASSERT_TRUE(check_consistency());
222+
}
223+
172224
template <class U, class R>
173225
std::vector<U> record_batch::to_vector(R&& range) const
174226
{

src/array.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ namespace sparrow
4747
return p_array->data_type();
4848
}
4949

50+
std::optional<std::string_view> array::name() const
51+
{
52+
return get_arrow_proxy().name();
53+
}
54+
55+
void array::set_name(std::optional<std::string_view> name)
56+
{
57+
get_arrow_proxy().set_name(name);
58+
}
59+
5060
bool array::empty() const
5161
{
5262
return size() == size_type(0);

src/record_batch.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,25 @@ namespace sparrow
3535
SPARROW_ASSERT_TRUE(check_consistency());
3636
}
3737

38+
record_batch::record_batch(struct_array&& arr)
39+
{
40+
SPARROW_ASSERT_TRUE(owns_arrow_array(arr));
41+
SPARROW_ASSERT_TRUE(owns_arrow_schema(arr));
42+
43+
auto [struct_arr, struct_sch] = extract_arrow_structures(std::move(arr));
44+
auto n_children = static_cast<std::size_t>(struct_arr.n_children);
45+
m_name_list.reserve(n_children);
46+
m_array_list.reserve(n_children);
47+
for (std::size_t i = 0; i < n_children; ++i)
48+
{
49+
array arr(move_array(*(struct_arr.children[i])), move_schema(*(struct_sch.children[i])));
50+
m_name_list.push_back(std::string(arr.name().value()));
51+
m_array_list.push_back(std::move(arr));
52+
}
53+
init_array_map();
54+
SPARROW_ASSERT_TRUE(check_consistency());
55+
}
56+
3857
record_batch::record_batch(const record_batch& rhs)
3958
: m_name_list(rhs.m_name_list)
4059
, m_array_list(rhs.m_array_list)
@@ -97,6 +116,16 @@ namespace sparrow
97116
return std::ranges::ref_view(m_array_list);
98117
}
99118

119+
struct_array record_batch::extract_struct_array()
120+
{
121+
for (std::size_t i = 0; i < m_name_list.size(); ++i)
122+
{
123+
m_array_list[i].set_name(m_name_list[i]);
124+
}
125+
m_array_map.clear();
126+
return struct_array(std::move(m_array_list));
127+
}
128+
100129
void record_batch::init_array_map()
101130
{
102131
m_array_map.clear();
@@ -114,6 +143,9 @@ namespace sparrow
114143
"The size of the names and of the array list must be the same"
115144
);
116145

146+
auto iter = std::find(m_name_list.begin(), m_name_list.end(), "");
147+
SPARROW_ASSERT(iter == m_name_list.end(), "A column can not have an empty name");
148+
117149
const auto unique_names = std::unordered_set<name_type>(m_name_list.begin(), m_name_list.end());
118150
SPARROW_ASSERT(unique_names.size() == m_name_list.size(), "The names of the columns must be unique");
119151

test/test_record_batch.cpp

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,22 @@ namespace sparrow
2525
{
2626
primitive_array<std::uint16_t> pr0(
2727
std::ranges::iota_view{std::size_t(0), std::size_t(data_size)}
28-
| std::views::transform(
29-
[](auto i)
30-
{
31-
return static_cast<std::uint16_t>(i);
32-
}
33-
)
28+
| std::views::transform(
29+
[](auto i)
30+
{
31+
return static_cast<std::uint16_t>(i);
32+
}
33+
),
34+
"column0"
35+
);
36+
primitive_array<std::int32_t> pr1(
37+
std::ranges::iota_view{std::int32_t(4), 4 + std::int32_t(data_size)},
38+
"column1"
39+
);
40+
primitive_array<std::int32_t> pr2(
41+
std::ranges::iota_view{std::int32_t(2), 2 + std::int32_t(data_size)},
42+
"column2"
3443
);
35-
primitive_array<std::int32_t> pr1(std::ranges::iota_view{std::int32_t(4), 4 + std::int32_t(data_size)});
36-
primitive_array<std::int32_t> pr2(std::ranges::iota_view{std::int32_t(2), 2 + std::int32_t(data_size)});
3744

3845
std::vector<array> arr_list = {array(std::move(pr0)), array(std::move(pr1)), array(std::move(pr2))};
3946
return arr_list;
@@ -71,6 +78,21 @@ namespace sparrow
7178
CHECK_EQ(record.nb_columns(), 3u);
7279
CHECK_EQ(record.nb_rows(), 10u);
7380
}
81+
82+
SUBCASE("from column list")
83+
{
84+
record_batch record(make_array_list(col_size));
85+
CHECK_EQ(record.nb_columns(), 3u);
86+
CHECK_EQ(record.nb_rows(), 10u);
87+
CHECK_FALSE(std::ranges::equal(record.names(), make_name_list()));
88+
}
89+
90+
SUBCASE("from struct array")
91+
{
92+
record_batch record0(struct_array(make_array_list(col_size)));
93+
record_batch record1(make_array_list(col_size));
94+
CHECK_EQ(record0, record1);
95+
}
7496
}
7597

7698
TEST_CASE("operator==")
@@ -164,6 +186,16 @@ namespace sparrow
164186
CHECK(res);
165187
}
166188

189+
TEST_CASE("extract_struct_array")
190+
{
191+
struct_array arr(make_array_list(col_size));
192+
struct_array control(arr);
193+
194+
record_batch r(std::move(arr));
195+
auto extr = r.extract_struct_array();
196+
CHECK_EQ(extr, control);
197+
}
198+
167199
#if defined(__cpp_lib_format)
168200
TEST_CASE("formatter")
169201
{

0 commit comments

Comments
 (0)