@@ -261,12 +261,11 @@ void NameToIdVisitor::Finish() {
261261// / \brief Visitor class for pruning schema columns based on selected field IDs.
262262// /
263263// / This visitor traverses a schema and creates a projected version containing only
264- // / the specified fields. It handles different projection modes:
265- // / - select_full_types=true: Include entire fields when their ID is selected
266- // / - select_full_types=false: Recursively project nested fields within selected structs
264+ // / the specified fields. When `select_full_types` is true, a field with all its
265+ // / sub-fields are selected if its field-id has been selected; otherwise, only leaf
266+ // / fields of selected field-ids are selected.
267267// /
268- // / \warning Error conditions that will cause projection to fail:
269- // / - Project or Select a Map with just key or value (returns InvalidArgument)
268+ // / \note It returns an error when projection is not successful.
270269class PruneColumnVisitor {
271270 public:
272271 PruneColumnVisitor (const std::unordered_set<int32_t >& selected_ids,
@@ -276,157 +275,80 @@ class PruneColumnVisitor {
276275 Result<std::shared_ptr<Type>> Visit (const std::shared_ptr<Type>& type) const {
277276 switch (type->type_id ()) {
278277 case TypeId::kStruct : {
279- auto struct_type = std::static_pointer_cast<StructType>(type);
280- return Visit (struct_type);
278+ return Visit (internal::checked_pointer_cast<StructType>(type));
281279 }
282280 case TypeId::kList : {
283- auto list_type = std::static_pointer_cast<ListType>(type);
284- return Visit (list_type);
281+ return Visit (internal::checked_pointer_cast<ListType>(type));
285282 }
286283 case TypeId::kMap : {
287- auto map_type = std::static_pointer_cast<MapType>(type);
288- return Visit (map_type);
284+ return Visit (internal::checked_pointer_cast<MapType>(type));
289285 }
290286 default : {
291- auto primitive_type = std::static_pointer_cast<PrimitiveType>(type);
292- return Visit (primitive_type);
287+ return nullptr ;
293288 }
294289 }
295290 }
296291
297- Result<std::shared_ptr<Type>> Visit (const std::shared_ptr<StructType>& type) const {
298- std::vector<std::shared_ptr<Type>> selected_types;
299- for (const auto & field : type->fields ()) {
300- if (select_full_types_ and selected_ids_.contains (field.field_id ())) {
301- selected_types.emplace_back (field.type ());
302- continue ;
303- }
304- ICEBERG_ASSIGN_OR_RAISE (auto child_result, Visit (field.type ()));
305- if (selected_ids_.contains (field.field_id ())) {
306- selected_types.emplace_back (
307- field.type ()->is_primitive () ? field.type () : std::move (child_result));
308- } else {
309- selected_types.emplace_back (std::move (child_result));
310- }
292+ Result<std::shared_ptr<Type>> Visit (const SchemaField& field) const {
293+ if (selected_ids_.contains (field.field_id ())) {
294+ return (select_full_types_ || field.type ()->is_primitive ()) ? field.type ()
295+ : Visit (field.type ());
311296 }
297+ return Visit (field.type ());
298+ }
312299
300+ static SchemaField MakeField (const SchemaField& field, std::shared_ptr<Type> type) {
301+ return {field.field_id (), std::string (field.name ()), std::move (type),
302+ field.optional (), std::string (field.doc ())};
303+ }
304+
305+ Result<std::shared_ptr<Type>> Visit (const std::shared_ptr<StructType>& type) const {
313306 bool same_types = true ;
314307 std::vector<SchemaField> selected_fields;
315- const auto & fields = type->fields ();
316- for (size_t i = 0 ; i < fields.size (); i++) {
317- if (fields[i].type () == selected_types[i]) {
318- selected_fields.emplace_back (std::move (fields[i]));
319- } else if (selected_types[i]) {
320- same_types = false ;
321- selected_fields.emplace_back (fields[i].field_id (), std::string (fields[i].name ()),
322- std::move (selected_types[i]), fields[i].optional (),
323- std::string (fields[i].doc ()));
308+ for (const auto & field : type->fields ()) {
309+ ICEBERG_ASSIGN_OR_RAISE (auto child_type, Visit (field));
310+ if (child_type) {
311+ same_types = same_types && (child_type == field.type ());
312+ selected_fields.emplace_back (MakeField (field, std::move (child_type)));
324313 }
325314 }
326315
327- if (!selected_fields.empty ()) {
328- if (same_types && selected_fields.size () == fields.size ()) {
329- return type;
330- } else {
331- return std::make_shared<StructType>(std::move (selected_fields));
332- }
316+ if (selected_fields.empty ()) {
317+ return nullptr ;
318+ } else if (same_types and selected_fields.size () == type->fields ().size ()) {
319+ return type;
333320 }
334-
335- return nullptr ;
321+ return std::make_shared<StructType>(std::move (selected_fields));
336322 }
337323
338324 Result<std::shared_ptr<Type>> Visit (const std::shared_ptr<ListType>& type) const {
339- const auto & element_field = type->fields ()[0 ];
340- if (select_full_types_ and selected_ids_.contains (element_field.field_id ())) {
325+ const auto & elem_field = type->fields ()[0 ];
326+ ICEBERG_ASSIGN_OR_RAISE (auto elem_type, Visit (elem_field));
327+ if (elem_type == nullptr ) {
328+ return nullptr ;
329+ } else if (elem_type == elem_field.type ()) {
341330 return type;
342331 }
343-
344- ICEBERG_ASSIGN_OR_RAISE (auto child_result, Visit (element_field.type ()));
345-
346- std::shared_ptr<Type> out;
347- if (selected_ids_.contains (element_field.field_id ())) {
348- if (element_field.type ()->is_primitive ()) {
349- out = std::make_shared<ListType>(element_field);
350- } else {
351- ICEBERG_ASSIGN_OR_RAISE (out, ProjectList (element_field, std::move (child_result)));
352- }
353- } else if (child_result) {
354- ICEBERG_ASSIGN_OR_RAISE (out, ProjectList (element_field, std::move (child_result)));
355- }
356- return out;
332+ return std::make_shared<ListType>(MakeField (elem_field, std::move (elem_type)));
357333 }
358334
359335 Result<std::shared_ptr<Type>> Visit (const std::shared_ptr<MapType>& type) const {
360336 const auto & key_field = type->fields ()[0 ];
361337 const auto & value_field = type->fields ()[1 ];
338+ ICEBERG_ASSIGN_OR_RAISE (auto key_type, Visit (key_field));
339+ ICEBERG_ASSIGN_OR_RAISE (auto value_type, Visit (value_field));
362340
363- if (select_full_types_ and selected_ids_.contains (key_field.field_id ()) and
364- selected_ids_.contains (value_field.field_id ())) {
365- return type;
366- }
367-
368- ICEBERG_ASSIGN_OR_RAISE (auto key_result, Visit (key_field.type ()));
369- ICEBERG_ASSIGN_OR_RAISE (auto value_result, Visit (value_field.type ()));
370-
371- if (selected_ids_.contains (value_field.field_id ()) and
372- value_field.type ()->is_primitive ()) {
373- value_result = value_field.type ();
374- }
375- if (selected_ids_.contains (key_field.field_id ()) and
376- key_field.type ()->is_primitive ()) {
377- key_result = key_field.type ();
378- }
379-
380- if (!key_result && !value_result) {
341+ if (key_type == nullptr && value_type == nullptr ) {
381342 return nullptr ;
343+ } else if (value_type == value_field.type () &&
344+ (key_type == key_field.type () || key_type == nullptr )) {
345+ return type;
346+ } else if (value_type == nullptr ) {
347+ return InvalidArgument (" Cannot project Map without value field" );
382348 }
383-
384- if (!key_result || !value_result) {
385- return InvalidArgument (
386- " Cannot project Map with only key or value: key={}, value={}" ,
387- key_result ? " present" : " null" , value_result ? " present" : " null" );
388- }
389-
390- ICEBERG_ASSIGN_OR_RAISE (auto out,
391- ProjectMap (key_field, value_field, key_result, value_result));
392- return out;
393- }
394-
395- Result<std::shared_ptr<Type>> Visit (const std::shared_ptr<PrimitiveType>& type) const {
396- return nullptr ;
397- }
398-
399- Result<std::shared_ptr<Type>> ProjectList (const SchemaField& element_field,
400- std::shared_ptr<Type> child_result) const {
401- if (!child_result) {
402- return nullptr ;
403- }
404- if (element_field.type () == child_result) {
405- return std::make_shared<ListType>(element_field);
406- }
407- return std::make_shared<ListType>(element_field.field_id (), child_result,
408- element_field.optional ());
409- }
410-
411- Result<std::shared_ptr<Type>> ProjectMap (const SchemaField& key_field,
412- const SchemaField& value_field,
413- std::shared_ptr<Type> key_result,
414- std::shared_ptr<Type> value_result) const {
415- SchemaField projected_key_field = key_field;
416- if (key_field.type () != key_result) {
417- projected_key_field =
418- SchemaField (key_field.field_id (), std::string (key_field.name ()), key_result,
419- key_field.optional ());
420- }
421-
422- SchemaField projected_value_field = value_field;
423- if (value_field.type () != value_result) {
424- projected_value_field =
425- SchemaField (value_field.field_id (), std::string (value_field.name ()),
426- value_result, value_field.optional ());
427- }
428-
429- return std::make_shared<MapType>(projected_key_field, projected_value_field);
349+ return std::make_shared<MapType>(
350+ (key_type == nullptr ? key_field : MakeField (key_field, std::move (key_type))),
351+ MakeField (value_field, std::move (value_type)));
430352 }
431353
432354 private:
0 commit comments