Skip to content

Commit 532bf48

Browse files
Added incremental record_batch API (#345)
* Added incremental record_batch API * Added missing break * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9e76f85 commit 532bf48

File tree

3 files changed

+92
-17
lines changed

3 files changed

+92
-17
lines changed

include/sparrow/record_batch.hpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,18 +156,36 @@ namespace sparrow
156156
*/
157157
SPARROW_API struct_array extract_struct_array();
158158

159+
/**
160+
* Appends the array \ref column to the record batch, and maps it with
161+
* \ref name.
162+
*
163+
* @param name The name of the column to append.
164+
* @param column The array to append.
165+
*/
166+
SPARROW_API void add_column(name_type name, array column);
167+
168+
/**
169+
* Appends the array \ref column to the record batch, and maps it to
170+
* its internal name. \ref column must have a name.
171+
*
172+
* @param column The array to append.
173+
*/
174+
SPARROW_API void add_column(array column);
175+
159176
private:
160177

161178
template <class U, class R>
162179
[[nodiscard]] std::vector<U> to_vector(R&& range) const;
163180

164-
SPARROW_API void init_array_map();
181+
SPARROW_API void update_array_map_cache() const;
165182

166183
[[nodiscard]] SPARROW_API bool check_consistency() const;
167184

168185
std::vector<name_type> m_name_list;
169186
std::vector<array> m_array_list;
170-
std::unordered_map<name_type, array*> m_array_map;
187+
mutable std::unordered_map<name_type, const array*> m_array_map;
188+
mutable bool m_dirty_map = true;
171189
};
172190

173191
/**
@@ -192,8 +210,7 @@ namespace sparrow
192210
: m_name_list(to_vector<name_type>(std::move(names)))
193211
, m_array_list(to_vector<array>(std::move(columns)))
194212
{
195-
init_array_map();
196-
SPARROW_ASSERT_TRUE(check_consistency());
213+
update_array_map_cache();
197214
}
198215

199216
namespace detail
@@ -217,8 +234,7 @@ namespace sparrow
217234
: m_name_list(detail::get_names(columns))
218235
, m_array_list(to_vector<array>(std::move(columns)))
219236
{
220-
init_array_map();
221-
SPARROW_ASSERT_TRUE(check_consistency());
237+
update_array_map_cache();
222238
}
223239

224240
template <class U, class R>

src/record_batch.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ namespace sparrow
3131
m_array_list.push_back(std::move(array));
3232
}
3333

34-
init_array_map();
35-
SPARROW_ASSERT_TRUE(check_consistency());
34+
update_array_map_cache();
3635
}
3736

3837
record_batch::record_batch(struct_array&& arr)
@@ -50,22 +49,21 @@ namespace sparrow
5049
m_name_list.push_back(std::string(arr.name().value()));
5150
m_array_list.push_back(std::move(arr));
5251
}
53-
init_array_map();
54-
SPARROW_ASSERT_TRUE(check_consistency());
52+
update_array_map_cache();
5553
}
5654

5755
record_batch::record_batch(const record_batch& rhs)
5856
: m_name_list(rhs.m_name_list)
5957
, m_array_list(rhs.m_array_list)
6058
{
61-
init_array_map();
59+
update_array_map_cache();
6260
}
6361

6462
record_batch& record_batch::operator=(const record_batch& rhs)
6563
{
6664
m_name_list = rhs.m_name_list;
6765
m_array_list = rhs.m_array_list;
68-
init_array_map();
66+
update_array_map_cache();
6967
return *this;
7068
}
7169

@@ -81,6 +79,7 @@ namespace sparrow
8179

8280
bool record_batch::contains_column(const name_type& name) const
8381
{
82+
update_array_map_cache();
8483
return m_array_map.contains(name);
8584
}
8685

@@ -92,6 +91,7 @@ namespace sparrow
9291

9392
const array& record_batch::get_column(const name_type& name) const
9493
{
94+
update_array_map_cache();
9595
const auto iter = m_array_map.find(name);
9696
if (iter == m_array_map.end())
9797
{
@@ -126,14 +126,40 @@ namespace sparrow
126126
return struct_array(std::move(m_array_list));
127127
}
128128

129-
void record_batch::init_array_map()
129+
void record_batch::add_column(name_type name, array column)
130130
{
131-
m_array_map.clear();
132-
m_array_map.reserve(m_name_list.size());
133-
for (size_t i = 0; i < m_name_list.size(); ++i)
131+
m_name_list.push_back(std::move(name));
132+
m_array_list.push_back(std::move(column));
133+
m_dirty_map = true;
134+
}
135+
136+
void record_batch::add_column(array column)
137+
{
138+
auto opt_col_name = column.name();
139+
SPARROW_ASSERT_TRUE(opt_col_name.has_value());
140+
std::string name(opt_col_name.value());
141+
add_column(std::move(name), std::move(column));
142+
}
143+
144+
void record_batch::update_array_map_cache() const
145+
{
146+
if (!m_dirty_map)
134147
{
135-
m_array_map.try_emplace(m_name_list[i], &(m_array_list[i]));
148+
return;
136149
}
150+
151+
// Columns can only be appened, so update the map
152+
// in reverse order and stops when it finds a name
153+
// already contained in it.
154+
for (std::size_t i = m_name_list.size(); i != 0; --i)
155+
{
156+
if (!m_array_map.try_emplace(m_name_list[i - 1], &(m_array_list[i - 1])).second)
157+
{
158+
break;
159+
}
160+
}
161+
m_dirty_map = false;
162+
SPARROW_ASSERT_TRUE(check_consistency());
137163
}
138164

139165
bool record_batch::check_consistency() const

test/test_record_batch.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,39 @@ namespace sparrow
196196
CHECK_EQ(extr, control);
197197
}
198198

199+
TEST_CASE("add_column")
200+
{
201+
auto record = make_record_batch(col_size);
202+
primitive_array<std::int32_t> pr3(
203+
std::ranges::iota_view{std::int32_t(3), 3 + std::int32_t(col_size)},
204+
"column3"
205+
);
206+
207+
auto ctrl = pr3;
208+
209+
record.add_column(array(std::move(pr3)));
210+
std::vector<std::string> ctrl_name_list = make_name_list();
211+
ctrl_name_list.push_back("column3");
212+
std::vector<std::string> name_list(record.names().begin(), record.names().end());
213+
CHECK_EQ(name_list, ctrl_name_list);
214+
215+
const auto& col3 = record.get_column(3);
216+
bool res = col3.visit(
217+
[&ctrl]<typename T>(const T& arg)
218+
{
219+
if constexpr (std::same_as<primitive_array<std::int32_t>, T>)
220+
{
221+
return arg == ctrl;
222+
}
223+
else
224+
{
225+
return false;
226+
}
227+
}
228+
);
229+
CHECK(res);
230+
}
231+
199232
#if defined(__cpp_lib_format)
200233
TEST_CASE("formatter")
201234
{

0 commit comments

Comments
 (0)