@@ -261,21 +261,20 @@ class PruneColumnVisitor {
261261 public:
262262 explicit PruneColumnVisitor (const std::unordered_set<int32_t >& selected_ids,
263263 bool select_full_types = false );
264- Status Visit (const ListType& type);
265- Status Visit (const MapType& type);
266- Status Visit (const StructType& type);
267- Status Visit (const PrimitiveType& type);
268- std::shared_ptr<const Type> GetResult () const ;
269- void SetResult (std::shared_ptr<const Type> result);
264+ Status Visit (const ListType& type, std::shared_ptr<const Type>& result);
265+ Status Visit (const MapType& type, std::shared_ptr<const Type>& result);
266+ Status Visit (const StructType& type, std::shared_ptr<const Type>& result);
267+ Status Visit (const PrimitiveType& type, std::shared_ptr<const Type>& result);
270268 Status ProjectList (const SchemaField& element,
271- std::shared_ptr<const Type> element_result);
269+ std::shared_ptr<const Type>& child_result,
270+ std::shared_ptr<const Type>& result);
272271 Status ProjectMap (const SchemaField& key_field, const SchemaField& value_field,
273- std::shared_ptr<const Type> value_result);
272+ std::shared_ptr<const Type>& value_result,
273+ std::shared_ptr<const Type>& result);
274274
275275 private:
276276 const std::unordered_set<int32_t >& selected_ids_;
277277 bool select_full_types_;
278- std::shared_ptr<const Type> result_;
279278};
280279
281280Result<std::shared_ptr<const Schema>> Schema::select (
@@ -303,20 +302,19 @@ Result<std::shared_ptr<const Schema>> Schema::internalSelect(
303302 }
304303 }
305304
305+ std::shared_ptr<const Type> result;
306306 PruneColumnVisitor visitor (selected_ids, /* select_full_types=*/ true );
307- ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*this , &visitor));
307+ ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*this , &visitor, result ));
308308
309- auto projected_type = visitor.GetResult ();
310- if (!projected_type) {
309+ if (!result) {
311310 return std::make_shared<Schema>(std::vector<SchemaField>{}, schema_id_);
312311 }
313312
314- if (projected_type ->type_id () != TypeId::kStruct ) {
313+ if (result ->type_id () != TypeId::kStruct ) {
315314 return InvalidSchema (" Projected type must be a struct type" );
316315 }
317316
318- const auto & projected_struct =
319- internal::checked_cast<const StructType&>(*projected_type);
317+ const auto & projected_struct = internal::checked_cast<const StructType&>(*result);
320318
321319 std::vector<SchemaField> fields_vec (projected_struct.fields ().begin (),
322320 projected_struct.fields ().end ());
@@ -326,19 +324,19 @@ Result<std::shared_ptr<const Schema>> Schema::internalSelect(
326324Result<std::shared_ptr<const Schema>> Schema::project (
327325 std::unordered_set<int32_t >& selected_ids) const {
328326 PruneColumnVisitor visitor (selected_ids, /* select_full_types=*/ false );
329- ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*this , &visitor));
330327
331- auto projected_type = visitor.GetResult ();
332- if (!projected_type) {
328+ std::shared_ptr<const Type> result;
329+ ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*this , &visitor, result));
330+
331+ if (!result) {
333332 return std::make_shared<Schema>(std::vector<SchemaField>{}, schema_id_);
334333 }
335334
336- if (projected_type ->type_id () != TypeId::kStruct ) {
335+ if (result ->type_id () != TypeId::kStruct ) {
337336 return InvalidSchema (" Projected type must be a struct type" );
338337 }
339338
340- const auto & projected_struct =
341- internal::checked_cast<const StructType&>(*projected_type);
339+ const auto & projected_struct = internal::checked_cast<const StructType&>(*result);
342340 std::vector<SchemaField> fields_vec (projected_struct.fields ().begin (),
343341 projected_struct.fields ().end ());
344342 return std::make_shared<Schema>(std::move (fields_vec), schema_id_);
@@ -348,29 +346,23 @@ PruneColumnVisitor::PruneColumnVisitor(const std::unordered_set<int32_t>& select
348346 bool select_full_types)
349347 : selected_ids_(selected_ids), select_full_types_(select_full_types) {}
350348
351- std::shared_ptr<const Type> PruneColumnVisitor::GetResult () const { return result_; }
352-
353- void PruneColumnVisitor::SetResult (std::shared_ptr<const Type> result) {
354- result_ = std::move (result);
355- }
356-
357- Status PruneColumnVisitor::Visit (const StructType& type) {
349+ Status PruneColumnVisitor::Visit (const StructType& type,
350+ std::shared_ptr<const Type>& result) {
358351 std::vector<std::shared_ptr<const Type>> selected_types;
359352 const auto & fields = type.fields ();
360353 for (const auto & field : fields) {
361- PruneColumnVisitor field_visitor (selected_ids_, select_full_types_);
362- ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*field.type (), &field_visitor));
363- auto result = field_visitor.GetResult ();
354+ std::shared_ptr<const Type> child_result;
355+ ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*field.type (), this , child_result));
364356 if (selected_ids_.contains (field.field_id ())) {
365357 // select
366358 if (select_full_types_) {
367359 selected_types.emplace_back (field.type ());
368360 } else if (field.type ()->type_id () == TypeId::kStruct ) {
369361 // project(kstruct)
370- if (!result ) {
371- result = std::make_shared<StructType>(std::vector<SchemaField>{});
362+ if (!child_result ) {
363+ child_result = std::make_shared<StructType>(std::vector<SchemaField>{});
372364 }
373- selected_types.emplace_back (std::move (result ));
365+ selected_types.emplace_back (std::move (child_result ));
374366 } else {
375367 // project(list, map, primitive)
376368 if (!field.type ()->is_primitive ()) {
@@ -381,9 +373,9 @@ Status PruneColumnVisitor::Visit(const StructType& type) {
381373 }
382374 selected_types.emplace_back (field.type ());
383375 }
384- } else if (result ) {
376+ } else if (child_result ) {
385377 // project, select
386- selected_types.emplace_back (std::move (result ));
378+ selected_types.emplace_back (std::move (child_result ));
387379 } else {
388380 // project, select
389381 selected_types.emplace_back (nullptr );
@@ -404,107 +396,116 @@ Status PruneColumnVisitor::Visit(const StructType& type) {
404396 }
405397
406398 if (!selected_fields.empty ()) {
407- if (selected_fields.size () == fields.size () && same_types ) {
408- result_ = std::make_shared<StructType>(type);
399+ if (same_types && selected_fields.size () == fields.size ()) {
400+ result = std::make_shared<StructType>(type);
409401 } else {
410- result_ = std::make_shared<StructType>(std::move (selected_fields));
402+ result = std::make_shared<StructType>(std::move (selected_fields));
411403 }
412404 }
413405
414406 return {};
415407}
416408
417- Status PruneColumnVisitor::Visit (const ListType& type) {
409+ Status PruneColumnVisitor::Visit (const ListType& type,
410+ std::shared_ptr<const Type>& result) {
418411 const auto & element_field = type.fields ()[0 ];
419412
420- PruneColumnVisitor element_visitor (selected_ids_, select_full_types_);
421- ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*element_field.type (), &element_visitor));
413+ if (select_full_types_ and selected_ids_.contains (element_field.field_id ())) {
414+ result = std::make_shared<ListType>(type);
415+ return {};
416+ }
422417
423- auto element_result = element_visitor.GetResult ();
418+ std::shared_ptr<const Type> child_result;
419+ ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*element_field.type (), this , child_result));
424420
425421 if (selected_ids_.contains (element_field.field_id ())) {
426- if (select_full_types_) {
427- result_ = std::make_shared<ListType>(element_field);
428- } else if (element_field.type ()->type_id () == TypeId::kStruct ) {
429- ICEBERG_RETURN_UNEXPECTED (ProjectList (element_field, element_result));
422+ if (element_field.type ()->type_id () == TypeId::kStruct ) {
423+ ICEBERG_RETURN_UNEXPECTED (ProjectList (element_field, child_result, result));
430424 } else {
431425 if (!element_field.type ()->is_primitive ()) {
432426 return InvalidArgument (
433427 " Cannot explicitly project List or Map types, List element {} of type {} was "
434428 " selected" ,
435429 element_field.field_id (), element_field.name ());
436430 }
437- result_ = std::make_shared<ListType>(element_field);
431+ result = std::make_shared<ListType>(element_field);
438432 }
439- } else if (element_result ) {
440- ICEBERG_RETURN_UNEXPECTED (ProjectList (element_field, element_result ));
433+ } else if (child_result ) {
434+ ICEBERG_RETURN_UNEXPECTED (ProjectList (element_field, child_result, result ));
441435 }
442436
443437 return {};
444438}
445439
446- Status PruneColumnVisitor::Visit (const MapType& type) {
440+ Status PruneColumnVisitor::Visit (const MapType& type,
441+ std::shared_ptr<const Type>& result) {
447442 const auto & key_field = type.fields ()[0 ];
448443 const auto & value_field = type.fields ()[1 ];
449444
450- PruneColumnVisitor key_visitor (selected_ids_, select_full_types_);
451- ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*key_field.type (), &key_visitor));
452- auto key_result = key_visitor.GetResult ();
445+ if (select_full_types_ and selected_ids_.contains (value_field.field_id ())) {
446+ result = std::make_shared<MapType>(type);
447+ return {};
448+ }
453449
454- PruneColumnVisitor value_visitor (selected_ids_, select_full_types_);
455- ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*value_field.type (), &value_visitor));
456- auto value_result = value_visitor.GetResult ();
450+ std::shared_ptr<const Type> key_result;
451+ ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*key_field.type (), this , key_result));
452+
453+ std::shared_ptr<const Type> value_result;
454+ ICEBERG_RETURN_UNEXPECTED (VisitTypeInline (*value_field.type (), this , value_result));
457455
458456 if (selected_ids_.contains (value_field.field_id ())) {
459- if (select_full_types_) {
460- result_ = std::make_shared<MapType>(type);
461- } else if (value_field.type ()->type_id () == TypeId::kStruct ) {
462- ICEBERG_RETURN_UNEXPECTED (ProjectMap (key_field, value_field, value_result));
457+ if (value_field.type ()->type_id () == TypeId::kStruct ) {
458+ ICEBERG_RETURN_UNEXPECTED (ProjectMap (key_field, value_field, value_result, result));
463459 } else {
464460 if (!value_field.type ()->is_primitive ()) {
465461 return InvalidArgument (
466462 " Cannot explicitly project List or Map types, Map value {} of type {} was "
467463 " selected" ,
468464 value_field.field_id (), type.ToString ());
469465 }
470- result_ = std::make_shared<MapType>(type);
466+ result = std::make_shared<MapType>(type);
471467 }
472468 } else if (value_result) {
473- ICEBERG_RETURN_UNEXPECTED (ProjectMap (key_field, value_field, value_result));
469+ ICEBERG_RETURN_UNEXPECTED (ProjectMap (key_field, value_field, value_result, result ));
474470 } else if (selected_ids_.contains (key_field.field_id ())) {
475- result_ = std::make_shared<MapType>(type);
471+ result = std::make_shared<MapType>(type);
476472 }
477473
478474 return {};
479475}
480476
481- Status PruneColumnVisitor::Visit (const PrimitiveType& type) { return {}; }
477+ Status PruneColumnVisitor::Visit (const PrimitiveType& type,
478+ std::shared_ptr<const Type>& result) {
479+ return {};
480+ }
482481
483482Status PruneColumnVisitor::ProjectList (const SchemaField& element_field,
484- std::shared_ptr<const Type> element_result) {
485- if (!element_result) {
483+ std::shared_ptr<const Type>& child_result,
484+ std::shared_ptr<const Type>& result) {
485+ if (!child_result) {
486486 return InvalidArgument (" Cannot project a list when the element result is null" );
487487 }
488- if (element_field.type () == element_result ) {
489- result_ = std::make_shared<ListType>(element_field);
488+ if (element_field.type () == child_result ) {
489+ result = std::make_shared<ListType>(element_field);
490490 } else {
491- result_ = std::make_shared<ListType>(element_field.field_id (),
492- std::const_pointer_cast<Type>(element_result ),
493- element_field.optional ());
491+ result = std::make_shared<ListType>(element_field.field_id (),
492+ std::const_pointer_cast<Type>(child_result ),
493+ element_field.optional ());
494494 }
495495 return {};
496496}
497497
498498Status PruneColumnVisitor::ProjectMap (const SchemaField& key_field,
499499 const SchemaField& value_field,
500- std::shared_ptr<const Type> value_result) {
500+ std::shared_ptr<const Type>& value_result,
501+ std::shared_ptr<const Type>& result) {
501502 if (!value_result) {
502503 return InvalidArgument (" Attempted to project a map without a defined map value type" );
503504 }
504505 if (value_field.type () == value_result) {
505- result_ = std::make_shared<MapType>(key_field, value_field);
506+ result = std::make_shared<MapType>(key_field, value_field);
506507 } else {
507- result_ = std::make_shared<MapType>(
508+ result = std::make_shared<MapType>(
508509 key_field,
509510 SchemaField (value_field.field_id (), std::string (value_field.name ()),
510511 std::const_pointer_cast<Type>(value_result), value_field.optional ()));
0 commit comments