Skip to content

Commit f7bcd93

Browse files
author
nullccxsy
committed
feat: implement schema selection and projection methods
- Added select and project methods to the Schema class for creating projection schemas based on specified field names or IDs.
- Introduced PruneColumnVisitor to handle the logic for selecting and projecting fields, including support for nested structures.
1 parent 88f5520 commit f7bcd93

File tree

3 files changed

+1077
-7
lines changed

3 files changed

+1077
-7
lines changed

src/iceberg/schema.cc

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,4 +257,259 @@ void NameToIdVisitor::Finish() {
257257
}
258258
}
259259

260+
class PruneColumnVisitor {
261+
public:
262+
explicit PruneColumnVisitor(const std::unordered_set<int32_t>& selected_ids,
263+
bool select_full_types = false);
264+
Status Visit(const ListType& type);
265+
Status Visit(const MapType& type);
266+
Status Visit(const StructType& type);
267+
Status Visit(const PrimitiveType& type);
268+
std::shared_ptr<const Type> GetResult() const;
269+
void SetResult(std::shared_ptr<const Type> result);
270+
Status ProjectList(const SchemaField& element,
271+
std::shared_ptr<const Type> element_result);
272+
Status ProjectMap(const SchemaField& key_field, const SchemaField& value_field,
273+
std::shared_ptr<const Type> value_result);
274+
275+
private:
276+
const std::unordered_set<int32_t>& selected_ids_;
277+
bool select_full_types_;
278+
std::shared_ptr<const Type> result_;
279+
};
280+
281+
/// \brief Build a projection schema containing the columns named in `names`.
///
/// Thin public entry point; all work is delegated to internalSelect().
///
/// \param names Column names to keep.
/// \param case_sensitive Forwarded to the name lookup.
Result<std::shared_ptr<const Schema>> Schema::select(
    const std::vector<std::string>& names, bool case_sensitive) const {
  return internalSelect(names, case_sensitive);
}
285+
286+
/// \brief Build a projection schema from a braced list of column names.
///
/// The names are copied into a vector and handed to internalSelect().
///
/// \param names Column names to keep.
/// \param case_sensitive Forwarded to the name lookup.
Result<std::shared_ptr<const Schema>> Schema::select(
    const std::initializer_list<std::string>& names, bool case_sensitive) const {
  const std::vector<std::string> name_list(names.begin(), names.end());
  return internalSelect(name_list, case_sensitive);
}
290+
291+
/// \brief Shared implementation behind the name-based select() overloads.
///
/// Resolves each name to a field ID with FindFieldByName(), then prunes this schema to
/// those IDs with a PruneColumnVisitor running in "select full types" mode.
///
/// - If any entry of `names` is "*", the whole schema is returned.
/// - Names that do not resolve to a field are ignored.
/// - If nothing ends up selected, an empty schema with the same schema ID is returned.
///
/// \param names Column names to keep.
/// \param case_sensitive Forwarded to FindFieldByName().
/// \return The projected schema, or the error raised by name lookup / pruning.
Result<std::shared_ptr<const Schema>> Schema::internalSelect(
    const std::vector<std::string>& names, bool case_sensitive) const {
  static constexpr std::string_view kAllColumns = "*";
  const bool wants_all_columns = std::ranges::any_of(
      names, [](const std::string& name) { return name == kAllColumns; });
  if (wants_all_columns) {
    // shared_from_this() throws std::bad_weak_ptr when this schema is not owned by a
    // std::shared_ptr (e.g. it lives on the stack). This API reports failures through
    // Result, so probe the ownership with weak_from_this() and fall back to an
    // equivalent copy instead of throwing.
    if (auto self = weak_from_this().lock()) {
      return self;
    }
    const auto& all_fields = fields();
    return std::make_shared<Schema>(
        std::vector<SchemaField>(all_fields.begin(), all_fields.end()), schema_id_);
  }

  std::unordered_set<int32_t> selected_ids;
  for (const auto& name : names) {
    ICEBERG_ASSIGN_OR_RAISE(auto result, FindFieldByName(name, case_sensitive));
    if (result.has_value()) {
      selected_ids.insert(result.value().get().field_id());
    }
  }

  PruneColumnVisitor visitor(selected_ids, /*select_full_types=*/true);
  ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*this, &visitor));

  auto projected_type = visitor.GetResult();
  if (!projected_type) {
    // Nothing was selected: the projection is an empty schema.
    return std::make_shared<Schema>(std::vector<SchemaField>{}, schema_id_);
  }

  if (projected_type->type_id() != TypeId::kStruct) {
    return InvalidSchema("Projected type must be a struct type");
  }

  const auto& projected_struct =
      internal::checked_cast<const StructType&>(*projected_type);
  const auto& projected_fields = projected_struct.fields();
  std::vector<SchemaField> fields_vec(projected_fields.begin(), projected_fields.end());
  return std::make_shared<Schema>(std::move(fields_vec), schema_id_);
}
325+
326+
/// \brief Build a projection schema containing only the fields whose IDs are selected.
///
/// Runs a PruneColumnVisitor over this schema with select_full_types disabled, so a
/// selected struct is narrowed to its selected children. If the visitor records no
/// result, an empty schema carrying the same schema ID is returned.
///
/// \param selected_ids Field IDs to keep.
/// \return The projected schema, an error from pruning, or InvalidSchema if the pruned
/// result is not a struct.
Result<std::shared_ptr<const Schema>> Schema::project(
    std::unordered_set<int32_t>& selected_ids) const {
  PruneColumnVisitor pruner(selected_ids, /*select_full_types=*/false);
  ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*this, &pruner));

  const std::shared_ptr<const Type> pruned = pruner.GetResult();

  // Stays empty when the visitor selected nothing.
  std::vector<SchemaField> kept_fields;
  if (pruned) {
    if (pruned->type_id() != TypeId::kStruct) {
      return InvalidSchema("Projected type must be a struct type");
    }
    const auto& pruned_struct = internal::checked_cast<const StructType&>(*pruned);
    const auto& pruned_fields = pruned_struct.fields();
    kept_fields.assign(pruned_fields.begin(), pruned_fields.end());
  }
  return std::make_shared<Schema>(std::move(kept_fields), schema_id_);
}
346+
347+
PruneColumnVisitor::PruneColumnVisitor(const std::unordered_set<int32_t>& selected_ids,
348+
bool select_full_types)
349+
: selected_ids_(selected_ids), select_full_types_(select_full_types) {}
350+
351+
/// \brief Return the pruned type recorded so far; null if nothing has been recorded.
std::shared_ptr<const Type> PruneColumnVisitor::GetResult() const {
  return result_;
}
352+
353+
/// \brief Replace the recorded result with `result`.
void PruneColumnVisitor::SetResult(std::shared_ptr<const Type> result) {
  // Take ownership of the argument; the previous value is released when the
  // by-value parameter is destroyed.
  result_.swap(result);
}
356+
357+
/// \brief Prune a struct to the fields that are selected or contain a selection.
///
/// Every child type is visited with a fresh visitor. For each field:
/// - selected, select-full mode: kept with its original type;
/// - selected struct, project mode: kept with the child's pruned type, or an empty
///   struct when the child visit selected nothing;
/// - selected list/map, project mode: rejected with InvalidArgument;
/// - selected primitive, project mode: kept as is;
/// - not selected: kept only if the child visit produced a result.
///
/// No result is recorded when no field survives. When every field survives with its
/// original type, the result is a copy of `type`; otherwise it is a new struct built
/// from the surviving fields.
Status PruneColumnVisitor::Visit(const StructType& type) {
  const auto& fields = type.fields();

  std::vector<SchemaField> kept_fields;
  bool all_types_unchanged = true;

  for (const auto& field : fields) {
    PruneColumnVisitor child_visitor(selected_ids_, select_full_types_);
    ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*field.type(), &child_visitor));

    // Starts as the child's pruned type (possibly null) and is overridden below when
    // the field itself is selected.
    std::shared_ptr<const Type> chosen_type = child_visitor.GetResult();

    if (selected_ids_.contains(field.field_id())) {
      if (select_full_types_) {
        chosen_type = field.type();
      } else if (field.type()->type_id() == TypeId::kStruct) {
        // An explicitly projected struct is kept even if none of its children are.
        if (!chosen_type) {
          chosen_type = std::make_shared<StructType>(std::vector<SchemaField>{});
        }
      } else {
        if (!field.type()->is_primitive()) {
          return InvalidArgument(
              "Cannot explicitly project List or Map types, {}:{} of type {} was "
              "selected",
              field.field_id(), field.name(), field.type()->ToString());
        }
        chosen_type = field.type();
      }
    }

    if (field.type() == chosen_type) {
      // Same type object: reuse the field untouched.
      kept_fields.emplace_back(field);
    } else if (chosen_type) {
      // Narrowed type: rebuild the field around it, preserving the other attributes.
      all_types_unchanged = false;
      kept_fields.emplace_back(field.field_id(), std::string(field.name()),
                               std::const_pointer_cast<Type>(chosen_type),
                               field.optional(), std::string(field.doc()));
    }
  }

  if (kept_fields.empty()) {
    return {};
  }

  if (kept_fields.size() == fields.size() && all_types_unchanged) {
    result_ = std::make_shared<StructType>(type);
  } else {
    result_ = std::make_shared<StructType>(std::move(kept_fields));
  }
  return {};
}
416+
417+
/// \brief Prune a list type according to the selected field IDs.
///
/// The element type is always visited first. Then:
/// - element not selected: the list is kept only if the element visit produced a
///   result, with the element narrowed to that result;
/// - element selected, select-full mode: the list is kept with its original element;
/// - element selected struct, project mode: the list is kept with the element narrowed
///   to its pruned struct, or to an empty struct when no child of the element was
///   selected (same rule as a selected struct field in Visit(const StructType&));
/// - element selected list/map, project mode: rejected with InvalidArgument;
/// - element selected primitive, project mode: the list is kept as is.
Status PruneColumnVisitor::Visit(const ListType& type) {
  const auto& element_field = type.fields()[0];

  PruneColumnVisitor element_visitor(selected_ids_, select_full_types_);
  ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*element_field.type(), &element_visitor));
  auto element_result = element_visitor.GetResult();

  if (!selected_ids_.contains(element_field.field_id())) {
    if (element_result) {
      ICEBERG_RETURN_UNEXPECTED(ProjectList(element_field, std::move(element_result)));
    }
    return {};
  }

  if (select_full_types_) {
    result_ = std::make_shared<ListType>(element_field);
    return {};
  }

  if (element_field.type()->type_id() == TypeId::kStruct) {
    // The struct element was selected explicitly: keep at least an empty struct
    // rather than failing because none of its children were selected.
    if (!element_result) {
      element_result = std::make_shared<StructType>(std::vector<SchemaField>{});
    }
    ICEBERG_RETURN_UNEXPECTED(ProjectList(element_field, std::move(element_result)));
    return {};
  }

  if (!element_field.type()->is_primitive()) {
    // Report the element's type (not its name) for the "of type {}" placeholder.
    return InvalidArgument(
        "Cannot explicitly project List or Map types, List element {} of type {} was "
        "selected",
        element_field.field_id(), element_field.type()->ToString());
  }

  result_ = std::make_shared<ListType>(element_field);
  return {};
}
445+
446+
/// \brief Prune a map type according to the selected field IDs.
///
/// Both the key and the value types are visited (the key visit only contributes its
/// error status). Then:
/// - value selected, select-full mode: the map is kept unchanged;
/// - value selected struct, project mode: the map is kept with the value narrowed to
///   its pruned struct, or to an empty struct when no child of the value was selected
///   (same rule as a selected struct field in Visit(const StructType&));
/// - value selected list/map, project mode: rejected with InvalidArgument;
/// - value selected primitive, project mode: the map is kept unchanged;
/// - value not selected but its visit produced a result: the map is kept with the
///   value narrowed to that result;
/// - otherwise, if the key is selected: the map is kept unchanged.
Status PruneColumnVisitor::Visit(const MapType& type) {
  const auto& key_field = type.fields()[0];
  const auto& value_field = type.fields()[1];

  PruneColumnVisitor key_visitor(selected_ids_, select_full_types_);
  ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*key_field.type(), &key_visitor));

  PruneColumnVisitor value_visitor(selected_ids_, select_full_types_);
  ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*value_field.type(), &value_visitor));
  auto value_result = value_visitor.GetResult();

  if (selected_ids_.contains(value_field.field_id())) {
    if (select_full_types_) {
      result_ = std::make_shared<MapType>(type);
    } else if (value_field.type()->type_id() == TypeId::kStruct) {
      // The struct value was selected explicitly: keep at least an empty struct
      // rather than failing because none of its children were selected.
      if (!value_result) {
        value_result = std::make_shared<StructType>(std::vector<SchemaField>{});
      }
      ICEBERG_RETURN_UNEXPECTED(
          ProjectMap(key_field, value_field, std::move(value_result)));
    } else {
      if (!value_field.type()->is_primitive()) {
        // Report the value's type, not the enclosing map type.
        return InvalidArgument(
            "Cannot explicitly project List or Map types, Map value {} of type {} was "
            "selected",
            value_field.field_id(), value_field.type()->ToString());
      }
      result_ = std::make_shared<MapType>(type);
    }
  } else if (value_result) {
    ICEBERG_RETURN_UNEXPECTED(
        ProjectMap(key_field, value_field, std::move(value_result)));
  } else if (selected_ids_.contains(key_field.field_id())) {
    result_ = std::make_shared<MapType>(type);
  }

  return {};
}
480+
481+
/// \brief Primitives have nothing to prune; no result is recorded here.
///
/// Whether a primitive field is kept is decided by the visitor of its parent.
Status PruneColumnVisitor::Visit(const PrimitiveType& /*type*/) {
  return {};
}
482+
483+
/// \brief Record a list type whose element type is `element_result`.
///
/// When `element_result` is the very same type object as the original element type,
/// the original element field is reused. Otherwise a new list is built from the
/// element's field ID and optionality around the narrowed type.
///
/// \return InvalidArgument if `element_result` is null.
Status PruneColumnVisitor::ProjectList(const SchemaField& element_field,
                                       std::shared_ptr<const Type> element_result) {
  if (!element_result) {
    return InvalidArgument("Cannot project a list when the element result is null");
  }

  const bool element_unchanged = (element_field.type() == element_result);
  if (element_unchanged) {
    result_ = std::make_shared<ListType>(element_field);
    return {};
  }

  result_ = std::make_shared<ListType>(element_field.field_id(),
                                       std::const_pointer_cast<Type>(element_result),
                                       element_field.optional());
  return {};
}
497+
498+
/// \brief Record a map type whose value type is `value_result`.
///
/// The key field is always kept as is. When `value_result` is the very same type
/// object as the original value type, the original value field is reused. Otherwise
/// the value field is rebuilt around the narrowed type, carrying over its ID, name,
/// optionality and doc string (as the struct visitor does when it rebuilds a field).
///
/// \return InvalidArgument if `value_result` is null.
Status PruneColumnVisitor::ProjectMap(const SchemaField& key_field,
                                      const SchemaField& value_field,
                                      std::shared_ptr<const Type> value_result) {
  if (!value_result) {
    return InvalidArgument("Attempted to project a map without a defined map value type");
  }

  if (value_field.type() == value_result) {
    result_ = std::make_shared<MapType>(key_field, value_field);
    return {};
  }

  result_ = std::make_shared<MapType>(
      key_field,
      SchemaField(value_field.field_id(), std::string(value_field.name()),
                  std::const_pointer_cast<Type>(value_result), value_field.optional(),
                  std::string(value_field.doc())));
  return {};
}
514+
260515
} // namespace iceberg

src/iceberg/schema.h

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cstdint>
2727
#include <optional>
2828
#include <string>
29+
#include <unordered_set>
2930
#include <vector>
3031

3132
#include "iceberg/iceberg_export.h"
@@ -41,7 +42,8 @@ namespace iceberg {
4142
/// A schema is a list of typed columns, along with a unique integer ID. A
4243
/// Table may have different schemas over its lifetime due to schema
4344
/// evolution.
44-
class ICEBERG_EXPORT Schema : public StructType {
45+
class ICEBERG_EXPORT Schema : public StructType,
46+
public std::enable_shared_from_this<Schema> {
4547
public:
4648
static constexpr int32_t kInitialSchemaId = 0;
4749

@@ -52,9 +54,9 @@ class ICEBERG_EXPORT Schema : public StructType {
5254
///
5355
/// A schema is identified by a unique ID for the purposes of schema
5456
/// evolution.
55-
[[nodiscard]] std::optional<int32_t> schema_id() const;
57+
std::optional<int32_t> schema_id() const;
5658

57-
[[nodiscard]] std::string ToString() const override;
59+
std::string ToString() const override;
5860

5961
/// \brief Find the SchemaField by field name.
6062
///
@@ -65,19 +67,41 @@ class ICEBERG_EXPORT Schema : public StructType {
6567
/// canonical name 'm.value.x'
6668
/// FIXME: Currently only handles ASCII lowercase conversion; extend to support
6769
/// non-ASCII characters (e.g., using std::towlower or ICU)
68-
[[nodiscard]] Result<std::optional<std::reference_wrapper<const SchemaField>>>
69-
FindFieldByName(std::string_view name, bool case_sensitive = true) const;
70+
Result<std::optional<std::reference_wrapper<const SchemaField>>> FindFieldByName(
71+
std::string_view name, bool case_sensitive = true) const;
7072

7173
/// \brief Find the SchemaField by field id.
72-
[[nodiscard]] Result<std::optional<std::reference_wrapper<const SchemaField>>>
73-
FindFieldById(int32_t field_id) const;
74+
Result<std::optional<std::reference_wrapper<const SchemaField>>> FindFieldById(
75+
int32_t field_id) const;
76+
77+
/// \brief Creates a projection schema for a subset of columns, selected by name.
78+
Result<std::shared_ptr<const Schema>> select(const std::vector<std::string>& names,
79+
bool case_sensitive = true) const;
80+
81+
/// \brief Creates a projection schema for a subset of columns, selected by name.
82+
Result<std::shared_ptr<const Schema>> select(
83+
const std::initializer_list<std::string>& names, bool case_sensitive = true) const;
84+
85+
/// \brief Creates a projection schema for a subset of columns, selected by name.
86+
template <typename... Args>
87+
Result<std::shared_ptr<const Schema>> select(Args&&... names,
88+
bool case_sensitive = true) const {
89+
static_assert((std::is_convertible_v<Args, std::string> && ...),
90+
"All arguments must be convertible to std::string");
91+
return select({std::string(names)...}, case_sensitive);
92+
}
93+
94+
Result<std::shared_ptr<const Schema>> project(std::unordered_set<int32_t>& ids) const;
7495

7596
friend bool operator==(const Schema& lhs, const Schema& rhs) { return lhs.Equals(rhs); }
7697

7798
private:
7899
/// \brief Compare two schemas for equality.
79100
[[nodiscard]] bool Equals(const Schema& other) const;
80101

102+
Result<std::shared_ptr<const Schema>> internalSelect(
103+
const std::vector<std::string>& names, bool case_sensitive) const;
104+
81105
// TODO(nullccxsy): Address potential concurrency issues in lazy initialization (e.g.,
82106
// use std::call_once)
83107
Status InitIdToFieldMap() const;

0 commit comments

Comments
 (0)