Skip to content

Commit 351edfe

Browse files
author
nullccxsy
committed
refactor: update PruneColumnVisitor to use shared pointers for result handling
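- Modified the PruneColumnVisitor class to pass results through `std::shared_ptr<const Type>&` out-parameters instead of storing them in a `result_` member (removing `GetResult`/`SetResult`), improving memory management and clarity.
- Updated Visit methods for ListType, MapType, and StructType to accommodate the new result handling approach.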
1 parent f7bcd93 commit 351edfe

File tree

1 file changed

+76
-75
lines changed

1 file changed

+76
-75
lines changed

src/iceberg/schema.cc

Lines changed: 76 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -261,21 +261,20 @@ class PruneColumnVisitor {
261261
public:
262262
explicit PruneColumnVisitor(const std::unordered_set<int32_t>& selected_ids,
263263
bool select_full_types = false);
264-
Status Visit(const ListType& type);
265-
Status Visit(const MapType& type);
266-
Status Visit(const StructType& type);
267-
Status Visit(const PrimitiveType& type);
268-
std::shared_ptr<const Type> GetResult() const;
269-
void SetResult(std::shared_ptr<const Type> result);
264+
Status Visit(const ListType& type, std::shared_ptr<const Type>& result);
265+
Status Visit(const MapType& type, std::shared_ptr<const Type>& result);
266+
Status Visit(const StructType& type, std::shared_ptr<const Type>& result);
267+
Status Visit(const PrimitiveType& type, std::shared_ptr<const Type>& result);
270268
Status ProjectList(const SchemaField& element,
271-
std::shared_ptr<const Type> element_result);
269+
std::shared_ptr<const Type>& child_result,
270+
std::shared_ptr<const Type>& result);
272271
Status ProjectMap(const SchemaField& key_field, const SchemaField& value_field,
273-
std::shared_ptr<const Type> value_result);
272+
std::shared_ptr<const Type>& value_result,
273+
std::shared_ptr<const Type>& result);
274274

275275
private:
276276
const std::unordered_set<int32_t>& selected_ids_;
277277
bool select_full_types_;
278-
std::shared_ptr<const Type> result_;
279278
};
280279

281280
Result<std::shared_ptr<const Schema>> Schema::select(
@@ -303,20 +302,19 @@ Result<std::shared_ptr<const Schema>> Schema::internalSelect(
303302
}
304303
}
305304

305+
std::shared_ptr<const Type> result;
306306
PruneColumnVisitor visitor(selected_ids, /*select_full_types=*/true);
307-
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*this, &visitor));
307+
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*this, &visitor, result));
308308

309-
auto projected_type = visitor.GetResult();
310-
if (!projected_type) {
309+
if (!result) {
311310
return std::make_shared<Schema>(std::vector<SchemaField>{}, schema_id_);
312311
}
313312

314-
if (projected_type->type_id() != TypeId::kStruct) {
313+
if (result->type_id() != TypeId::kStruct) {
315314
return InvalidSchema("Projected type must be a struct type");
316315
}
317316

318-
const auto& projected_struct =
319-
internal::checked_cast<const StructType&>(*projected_type);
317+
const auto& projected_struct = internal::checked_cast<const StructType&>(*result);
320318

321319
std::vector<SchemaField> fields_vec(projected_struct.fields().begin(),
322320
projected_struct.fields().end());
@@ -326,19 +324,19 @@ Result<std::shared_ptr<const Schema>> Schema::internalSelect(
326324
Result<std::shared_ptr<const Schema>> Schema::project(
327325
std::unordered_set<int32_t>& selected_ids) const {
328326
PruneColumnVisitor visitor(selected_ids, /*select_full_types=*/false);
329-
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*this, &visitor));
330327

331-
auto projected_type = visitor.GetResult();
332-
if (!projected_type) {
328+
std::shared_ptr<const Type> result;
329+
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*this, &visitor, result));
330+
331+
if (!result) {
333332
return std::make_shared<Schema>(std::vector<SchemaField>{}, schema_id_);
334333
}
335334

336-
if (projected_type->type_id() != TypeId::kStruct) {
335+
if (result->type_id() != TypeId::kStruct) {
337336
return InvalidSchema("Projected type must be a struct type");
338337
}
339338

340-
const auto& projected_struct =
341-
internal::checked_cast<const StructType&>(*projected_type);
339+
const auto& projected_struct = internal::checked_cast<const StructType&>(*result);
342340
std::vector<SchemaField> fields_vec(projected_struct.fields().begin(),
343341
projected_struct.fields().end());
344342
return std::make_shared<Schema>(std::move(fields_vec), schema_id_);
@@ -348,29 +346,23 @@ PruneColumnVisitor::PruneColumnVisitor(const std::unordered_set<int32_t>& select
348346
bool select_full_types)
349347
: selected_ids_(selected_ids), select_full_types_(select_full_types) {}
350348

351-
std::shared_ptr<const Type> PruneColumnVisitor::GetResult() const { return result_; }
352-
353-
void PruneColumnVisitor::SetResult(std::shared_ptr<const Type> result) {
354-
result_ = std::move(result);
355-
}
356-
357-
Status PruneColumnVisitor::Visit(const StructType& type) {
349+
Status PruneColumnVisitor::Visit(const StructType& type,
350+
std::shared_ptr<const Type>& result) {
358351
std::vector<std::shared_ptr<const Type>> selected_types;
359352
const auto& fields = type.fields();
360353
for (const auto& field : fields) {
361-
PruneColumnVisitor field_visitor(selected_ids_, select_full_types_);
362-
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*field.type(), &field_visitor));
363-
auto result = field_visitor.GetResult();
354+
std::shared_ptr<const Type> child_result;
355+
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*field.type(), this, child_result));
364356
if (selected_ids_.contains(field.field_id())) {
365357
// select
366358
if (select_full_types_) {
367359
selected_types.emplace_back(field.type());
368360
} else if (field.type()->type_id() == TypeId::kStruct) {
369361
// project(kstruct)
370-
if (!result) {
371-
result = std::make_shared<StructType>(std::vector<SchemaField>{});
362+
if (!child_result) {
363+
child_result = std::make_shared<StructType>(std::vector<SchemaField>{});
372364
}
373-
selected_types.emplace_back(std::move(result));
365+
selected_types.emplace_back(std::move(child_result));
374366
} else {
375367
// project(list, map, primitive)
376368
if (!field.type()->is_primitive()) {
@@ -381,9 +373,9 @@ Status PruneColumnVisitor::Visit(const StructType& type) {
381373
}
382374
selected_types.emplace_back(field.type());
383375
}
384-
} else if (result) {
376+
} else if (child_result) {
385377
// project, select
386-
selected_types.emplace_back(std::move(result));
378+
selected_types.emplace_back(std::move(child_result));
387379
} else {
388380
// project, select
389381
selected_types.emplace_back(nullptr);
@@ -404,107 +396,116 @@ Status PruneColumnVisitor::Visit(const StructType& type) {
404396
}
405397

406398
if (!selected_fields.empty()) {
407-
if (selected_fields.size() == fields.size() && same_types) {
408-
result_ = std::make_shared<StructType>(type);
399+
if (same_types && selected_fields.size() == fields.size()) {
400+
result = std::make_shared<StructType>(type);
409401
} else {
410-
result_ = std::make_shared<StructType>(std::move(selected_fields));
402+
result = std::make_shared<StructType>(std::move(selected_fields));
411403
}
412404
}
413405

414406
return {};
415407
}
416408

417-
Status PruneColumnVisitor::Visit(const ListType& type) {
409+
Status PruneColumnVisitor::Visit(const ListType& type,
410+
std::shared_ptr<const Type>& result) {
418411
const auto& element_field = type.fields()[0];
419412

420-
PruneColumnVisitor element_visitor(selected_ids_, select_full_types_);
421-
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*element_field.type(), &element_visitor));
413+
if (select_full_types_ and selected_ids_.contains(element_field.field_id())) {
414+
result = std::make_shared<ListType>(type);
415+
return {};
416+
}
422417

423-
auto element_result = element_visitor.GetResult();
418+
std::shared_ptr<const Type> child_result;
419+
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*element_field.type(), this, child_result));
424420

425421
if (selected_ids_.contains(element_field.field_id())) {
426-
if (select_full_types_) {
427-
result_ = std::make_shared<ListType>(element_field);
428-
} else if (element_field.type()->type_id() == TypeId::kStruct) {
429-
ICEBERG_RETURN_UNEXPECTED(ProjectList(element_field, element_result));
422+
if (element_field.type()->type_id() == TypeId::kStruct) {
423+
ICEBERG_RETURN_UNEXPECTED(ProjectList(element_field, child_result, result));
430424
} else {
431425
if (!element_field.type()->is_primitive()) {
432426
return InvalidArgument(
433427
"Cannot explicitly project List or Map types, List element {} of type {} was "
434428
"selected",
435429
element_field.field_id(), element_field.name());
436430
}
437-
result_ = std::make_shared<ListType>(element_field);
431+
result = std::make_shared<ListType>(element_field);
438432
}
439-
} else if (element_result) {
440-
ICEBERG_RETURN_UNEXPECTED(ProjectList(element_field, element_result));
433+
} else if (child_result) {
434+
ICEBERG_RETURN_UNEXPECTED(ProjectList(element_field, child_result, result));
441435
}
442436

443437
return {};
444438
}
445439

446-
Status PruneColumnVisitor::Visit(const MapType& type) {
440+
Status PruneColumnVisitor::Visit(const MapType& type,
441+
std::shared_ptr<const Type>& result) {
447442
const auto& key_field = type.fields()[0];
448443
const auto& value_field = type.fields()[1];
449444

450-
PruneColumnVisitor key_visitor(selected_ids_, select_full_types_);
451-
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*key_field.type(), &key_visitor));
452-
auto key_result = key_visitor.GetResult();
445+
if (select_full_types_ and selected_ids_.contains(value_field.field_id())) {
446+
result = std::make_shared<MapType>(type);
447+
return {};
448+
}
453449

454-
PruneColumnVisitor value_visitor(selected_ids_, select_full_types_);
455-
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*value_field.type(), &value_visitor));
456-
auto value_result = value_visitor.GetResult();
450+
std::shared_ptr<const Type> key_result;
451+
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*key_field.type(), this, key_result));
452+
453+
std::shared_ptr<const Type> value_result;
454+
ICEBERG_RETURN_UNEXPECTED(VisitTypeInline(*value_field.type(), this, value_result));
457455

458456
if (selected_ids_.contains(value_field.field_id())) {
459-
if (select_full_types_) {
460-
result_ = std::make_shared<MapType>(type);
461-
} else if (value_field.type()->type_id() == TypeId::kStruct) {
462-
ICEBERG_RETURN_UNEXPECTED(ProjectMap(key_field, value_field, value_result));
457+
if (value_field.type()->type_id() == TypeId::kStruct) {
458+
ICEBERG_RETURN_UNEXPECTED(ProjectMap(key_field, value_field, value_result, result));
463459
} else {
464460
if (!value_field.type()->is_primitive()) {
465461
return InvalidArgument(
466462
"Cannot explicitly project List or Map types, Map value {} of type {} was "
467463
"selected",
468464
value_field.field_id(), type.ToString());
469465
}
470-
result_ = std::make_shared<MapType>(type);
466+
result = std::make_shared<MapType>(type);
471467
}
472468
} else if (value_result) {
473-
ICEBERG_RETURN_UNEXPECTED(ProjectMap(key_field, value_field, value_result));
469+
ICEBERG_RETURN_UNEXPECTED(ProjectMap(key_field, value_field, value_result, result));
474470
} else if (selected_ids_.contains(key_field.field_id())) {
475-
result_ = std::make_shared<MapType>(type);
471+
result = std::make_shared<MapType>(type);
476472
}
477473

478474
return {};
479475
}
480476

481-
Status PruneColumnVisitor::Visit(const PrimitiveType& type) { return {}; }
477+
Status PruneColumnVisitor::Visit(const PrimitiveType& type,
478+
std::shared_ptr<const Type>& result) {
479+
return {};
480+
}
482481

483482
Status PruneColumnVisitor::ProjectList(const SchemaField& element_field,
484-
std::shared_ptr<const Type> element_result) {
485-
if (!element_result) {
483+
std::shared_ptr<const Type>& child_result,
484+
std::shared_ptr<const Type>& result) {
485+
if (!child_result) {
486486
return InvalidArgument("Cannot project a list when the element result is null");
487487
}
488-
if (element_field.type() == element_result) {
489-
result_ = std::make_shared<ListType>(element_field);
488+
if (element_field.type() == child_result) {
489+
result = std::make_shared<ListType>(element_field);
490490
} else {
491-
result_ = std::make_shared<ListType>(element_field.field_id(),
492-
std::const_pointer_cast<Type>(element_result),
493-
element_field.optional());
491+
result = std::make_shared<ListType>(element_field.field_id(),
492+
std::const_pointer_cast<Type>(child_result),
493+
element_field.optional());
494494
}
495495
return {};
496496
}
497497

498498
Status PruneColumnVisitor::ProjectMap(const SchemaField& key_field,
499499
const SchemaField& value_field,
500-
std::shared_ptr<const Type> value_result) {
500+
std::shared_ptr<const Type>& value_result,
501+
std::shared_ptr<const Type>& result) {
501502
if (!value_result) {
502503
return InvalidArgument("Attempted to project a map without a defined map value type");
503504
}
504505
if (value_field.type() == value_result) {
505-
result_ = std::make_shared<MapType>(key_field, value_field);
506+
result = std::make_shared<MapType>(key_field, value_field);
506507
} else {
507-
result_ = std::make_shared<MapType>(
508+
result = std::make_shared<MapType>(
508509
key_field,
509510
SchemaField(value_field.field_id(), std::string(value_field.name()),
510511
std::const_pointer_cast<Type>(value_result), value_field.optional()));

0 commit comments

Comments
 (0)