Skip to content

Commit 764815b

Browse files
sahil1105IsaacWarren
authored andcommitted
Update Bodo patch to include early casting changes in Parquet reader
1 parent 5ce1e3d commit 764815b

File tree

1 file changed

+101
-93
lines changed

1 file changed

+101
-93
lines changed

recipe/patches/0004-Bodo-Changes.patch

Lines changed: 101 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,103 @@
11
diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
2-
index c17ba89be7..15bde86ba4 100644
2+
index 1f8b6cc488..5ad7a5f78b 100644
33
--- a/cpp/src/arrow/dataset/file_parquet.cc
44
+++ b/cpp/src/arrow/dataset/file_parquet.cc
5-
@@ -36,6 +36,7 @@
5+
@@ -26,16 +26,23 @@
6+
7+
#include "arrow/compute/cast.h"
8+
#include "arrow/compute/exec.h"
9+
+#include "arrow/dataset/dataset.h"
10+
#include "arrow/dataset/dataset_internal.h"
11+
#include "arrow/dataset/parquet_encryption_config.h"
12+
#include "arrow/dataset/scanner.h"
13+
#include "arrow/filesystem/path_util.h"
14+
+#include "arrow/memory_pool.h"
15+
+#include "arrow/record_batch.h"
16+
+#include "arrow/result.h"
17+
#include "arrow/table.h"
18+
+#include "arrow/type.h"
19+
+#include "arrow/type_fwd.h"
20+
#include "arrow/util/checked_cast.h"
21+
#include "arrow/util/future.h"
622
#include "arrow/util/iterator.h"
723
#include "arrow/util/logging.h"
824
#include "arrow/util/range.h"
925
+#include "arrow/util/thread_pool.h"
1026
#include "arrow/util/tracing_internal.h"
1127
#include "parquet/arrow/reader.h"
1228
#include "parquet/arrow/schema.h"
13-
@@ -630,10 +631,15 @@ Result<RecordBatchGenerator> ParquetFileFormat::ScanBatchesAsync(
29+
@@ -555,6 +562,60 @@ Future<std::shared_ptr<parquet::arrow::FileReader>> ParquetFileFormat::GetReader
30+
});
31+
}
32+
33+
+struct CastingGenerator {
34+
+ CastingGenerator(RecordBatchGenerator source, std::shared_ptr<Schema> final_schema,
35+
+ arrow::MemoryPool* pool = arrow::default_memory_pool())
36+
+ : source_(source),
37+
+ final_schema_(final_schema),
38+
+ exec_ctx(std::make_shared<compute::ExecContext>(pool)) {}
39+
+
40+
+ Future<std::shared_ptr<RecordBatch>> operator()() {
41+
+ return this->source_().Then(
42+
+ [this](const std::shared_ptr<RecordBatch>& next) -> std::shared_ptr<RecordBatch> {
43+
+ if (IsIterationEnd(next)) {
44+
+ return next;
45+
+ }
46+
+ std::vector<std::shared_ptr<::arrow::Array>> out_cols;
47+
+ std::vector<std::shared_ptr<arrow::Field>> out_schema_fields;
48+
+
49+
+ bool changed = false;
50+
+ for (const auto& field : this->final_schema_->fields()) {
51+
+ FieldRef field_ref = FieldRef(field->name());
52+
+ auto column_st = field_ref.GetOneOrNone(*next);
53+
+ std::shared_ptr<Array> column = column_st.ValueUnsafe();
54+
+ if (column) {
55+
+ if (!column->type()->Equals(field->type())) {
56+
+ // Referenced field was present but didn't have the expected type.
57+
+ auto converted_st =
58+
+ compute::Cast(column, field->type(), compute::CastOptions::Safe(),
59+
+ this->exec_ctx.get());
60+
+ auto converted = std::move(converted_st.ValueUnsafe());
61+
+ column = converted.make_array();
62+
+ changed = true;
63+
+ }
64+
+ out_cols.emplace_back(std::move(column));
65+
+ out_schema_fields.emplace_back(field->Copy());
66+
+ // XXX Do we need to handle the else case? What happens when the column
67+
+ // doesn't exist, e.g. all null or all the same value?
68+
+ }
69+
+ }
70+
+
71+
+ if (changed) {
72+
+ return RecordBatch::Make(
73+
+ std::make_shared<Schema>(std::move(out_schema_fields),
74+
+ next->schema()->metadata()),
75+
+ next->num_rows(), std::move(out_cols));
76+
+ } else {
77+
+ return next;
78+
+ }
79+
+ });
80+
+ }
81+
+
82+
+ RecordBatchGenerator source_;
83+
+ std::shared_ptr<Schema> final_schema_;
84+
+ std::shared_ptr<compute::ExecContext> exec_ctx;
85+
+};
86+
+
87+
struct SlicingGenerator {
88+
SlicingGenerator(RecordBatchGenerator source, int64_t batch_size)
89+
: state(std::make_shared<State>(source, batch_size)) {}
90+
@@ -617,6 +678,9 @@ Result<RecordBatchGenerator> ParquetFileFormat::ScanBatchesAsync(
91+
[this, options, parquet_fragment, pre_filtered,
92+
row_groups](const std::shared_ptr<parquet::arrow::FileReader>& reader) mutable
93+
-> Result<RecordBatchGenerator> {
94+
+ // Since we already do the batching through the SlicingGenerator, we don't need the
95+
+ // reader to batch its output.
96+
+ reader->set_batch_size(std::numeric_limits<int64_t>::max());
97+
// Ensure that parquet_fragment has FileMetaData
98+
RETURN_NOT_OK(parquet_fragment->EnsureCompleteMetadata(reader.get()));
99+
if (!pre_filtered) {
100+
@@ -633,10 +697,17 @@ Result<RecordBatchGenerator> ParquetFileFormat::ScanBatchesAsync(
14101
kParquetTypeName, options.get(), default_fragment_scan_options));
15102
int batch_readahead = options->batch_readahead;
16103
int64_t rows_to_readahead = batch_readahead * options->batch_size;
@@ -27,30 +114,16 @@ index c17ba89be7..15bde86ba4 100644
27114
+ ARROW_ASSIGN_OR_RAISE(auto generator, reader->GetRecordBatchGenerator(
28115
+ reader, row_groups, column_projection,
29116
+ cpu_executor, rows_to_readahead));
117+
+ generator =
118+
+ CastingGenerator(std::move(generator), options->dataset_schema, options->pool);
30119
RecordBatchGenerator sliced =
31120
SlicingGenerator(std::move(generator), options->batch_size);
32121
if (batch_readahead == 0) {
33122
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
34-
index 18981d1451..cdf5f586b4 100644
123+
index a856a792a2..5c10dfc6ac 100644
35124
--- a/cpp/src/arrow/dataset/scanner.cc
36125
+++ b/cpp/src/arrow/dataset/scanner.cc
37-
@@ -302,6 +302,7 @@ Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
38-
{"arrow.dataset.fragment.type_name", fragment.value->type_name()},
39-
});
40-
#endif
41-
+ // This is the call site.
42-
ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options));
43-
ArrayVector columns;
44-
for (const auto& field : options->dataset_schema->fields()) {
45-
@@ -327,6 +328,7 @@ Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
46-
Result<AsyncGenerator<EnumeratedRecordBatchGenerator>> FragmentsToBatches(
47-
FragmentGenerator fragment_gen, const std::shared_ptr<ScanOptions>& options) {
48-
auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
49-
+ // This is the call-site.
50-
auto batch_gen_gen =
51-
MakeMappedGenerator(std::move(enumerated_fragment_gen),
52-
[=](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
53-
@@ -353,8 +355,10 @@ class OneShotFragment : public Fragment {
126+
@@ -355,8 +355,10 @@ class OneShotFragment : public Fragment {
54127
ARROW_ASSIGN_OR_RAISE(
55128
auto background_gen,
56129
MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor()));
@@ -63,7 +136,7 @@ index 18981d1451..cdf5f586b4 100644
63136
}
64137
std::string type_name() const override { return "one-shot"; }
65138

66-
@@ -380,7 +384,7 @@ Result<TaggedRecordBatchIterator> AsyncScanner::ScanBatches() {
139+
@@ -382,7 +384,7 @@ Result<TaggedRecordBatchIterator> AsyncScanner::ScanBatches() {
67140
[this](::arrow::internal::Executor* executor) {
68141
return ScanBatchesAsync(executor);
69142
},
@@ -72,7 +145,7 @@ index 18981d1451..cdf5f586b4 100644
72145
}
73146

74147
Result<EnumeratedRecordBatchIterator> AsyncScanner::ScanBatchesUnordered() {
75-
@@ -388,7 +392,7 @@ Result<EnumeratedRecordBatchIterator> AsyncScanner::ScanBatchesUnordered() {
148+
@@ -390,7 +392,7 @@ Result<EnumeratedRecordBatchIterator> AsyncScanner::ScanBatchesUnordered() {
76149
[this](::arrow::internal::Executor* executor) {
77150
return ScanBatchesUnorderedAsync(executor);
78151
},
@@ -81,7 +154,7 @@ index 18981d1451..cdf5f586b4 100644
81154
}
82155

83156
Result<std::shared_ptr<Table>> AsyncScanner::ToTable() {
84-
@@ -398,7 +402,7 @@ Result<std::shared_ptr<Table>> AsyncScanner::ToTable() {
157+
@@ -400,7 +402,7 @@ Result<std::shared_ptr<Table>> AsyncScanner::ToTable() {
85158
}
86159

87160
Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync() {
@@ -90,15 +163,7 @@ index 18981d1451..cdf5f586b4 100644
90163
/*sequence_fragments=*/false);
91164
}
92165

93-
@@ -443,6 +447,7 @@ Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
94-
scan_options_->projection.call()->options.get())
95-
->field_names;
96-
97-
+ // This is where the node is added to the plan.
98-
RETURN_NOT_OK(
99-
acero::Declaration::Sequence(
100-
{
101-
@@ -599,11 +604,12 @@ Result<std::shared_ptr<Table>> AsyncScanner::Head(int64_t num_rows) {
166+
@@ -601,7 +603,7 @@ Result<std::shared_ptr<Table>> AsyncScanner::Head(int64_t num_rows) {
102167
}
103168

104169
Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync() {
@@ -107,12 +172,7 @@ index 18981d1451..cdf5f586b4 100644
107172
}
108173

109174
Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync(
110-
Executor* cpu_executor) {
111-
+ // Is this part of the code path?
112-
ARROW_ASSIGN_OR_RAISE(
113-
auto unordered, ScanBatchesUnorderedAsync(cpu_executor, /*sequence_fragments=*/true,
114-
/*use_legacy_batching=*/true));
115-
@@ -775,7 +781,7 @@ Future<int64_t> AsyncScanner::CountRowsAsync(Executor* executor) {
175+
@@ -778,7 +780,7 @@ Future<int64_t> AsyncScanner::CountRowsAsync(Executor* executor) {
116176
}
117177

118178
Future<int64_t> AsyncScanner::CountRowsAsync() {
@@ -121,24 +181,8 @@ index 18981d1451..cdf5f586b4 100644
121181
}
122182

123183
Result<int64_t> AsyncScanner::CountRows() {
124-
@@ -999,6 +1005,7 @@ Result<acero::ExecNode*> MakeScanNode(acero::ExecPlan* plan,
125-
ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector());
126-
auto fragment_gen = MakeVectorGenerator(std::move(fragments_vec));
127-
128-
+ // This is the call site.
129-
ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen,
130-
FragmentsToBatches(std::move(fragment_gen), scan_options));
131-
132-
@@ -1168,6 +1175,7 @@ Result<acero::ExecNode*> MakeOrderedSinkNode(acero::ExecPlan* plan,
133-
134-
namespace internal {
135-
void InitializeScanner(arrow::acero::ExecFactoryRegistry* registry) {
136-
+ // This is where it's registered.
137-
DCHECK_OK(registry->AddFactory("scan", MakeScanNode));
138-
DCHECK_OK(registry->AddFactory("ordered_sink", MakeOrderedSinkNode));
139-
DCHECK_OK(registry->AddFactory("augmented_project", MakeAugmentedProjectNode));
140184
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
141-
index 4479158ff2..301cdc0517 100644
185+
index d2de267897..1c605c1bf2 100644
142186
--- a/cpp/src/arrow/dataset/scanner.h
143187
+++ b/cpp/src/arrow/dataset/scanner.h
144188
@@ -107,6 +107,11 @@ struct ARROW_DS_EXPORT ScanOptions {
@@ -153,7 +197,7 @@ index 4479158ff2..301cdc0517 100644
153197
/// If true the scanner will scan in parallel
154198
///
155199
/// Note: If true, this will use threads from both the cpu_executor and the
156-
@@ -437,6 +442,11 @@ class ARROW_DS_EXPORT Scanner {
200+
@@ -442,6 +447,11 @@ class ARROW_DS_EXPORT Scanner {
157201
TaggedRecordBatchIterator scan);
158202

159203
const std::shared_ptr<ScanOptions> scan_options_;
@@ -216,39 +260,3 @@ index 44b1e227b0..218edc60ca 100644
216260
}
217261

218262
} // namespace internal
219-
diff --git a/cpp/src/parquet/arrow/reader.cc b/cpp/src/parquet/arrow/reader.cc
220-
index d6ad7c25bc..5ade5bb747 100644
221-
--- a/cpp/src/parquet/arrow/reader.cc
222-
+++ b/cpp/src/parquet/arrow/reader.cc
223-
@@ -1153,6 +1153,7 @@ class RowGroupGenerator {
224-
const int row_group, const std::vector<int>& column_indices) {
225-
// Skips bound checks/pre-buffering, since we've done that already
226-
const int64_t batch_size = self->properties().batch_size();
227-
+ // This the main location.
228-
return self->DecodeRowGroups(self, {row_group}, column_indices, cpu_executor)
229-
.Then([batch_size](const std::shared_ptr<Table>& table)
230-
-> ::arrow::Result<RecordBatchGenerator> {
231-
@@ -1190,6 +1191,7 @@ FileReaderImpl::GetRecordBatchGenerator(std::shared_ptr<FileReader> reader,
232-
reader_properties_.cache_options());
233-
END_PARQUET_CATCH_EXCEPTIONS
234-
}
235-
+ // This is where it's created it seems.
236-
::arrow::AsyncGenerator<RowGroupGenerator::RecordBatchGenerator> row_group_generator =
237-
RowGroupGenerator(::arrow::internal::checked_pointer_cast<FileReaderImpl>(reader),
238-
cpu_executor, row_group_indices, column_indices,
239-
@@ -1228,6 +1230,7 @@ Status FileReaderImpl::ReadRowGroups(const std::vector<int>& row_groups,
240-
END_PARQUET_CATCH_EXCEPTIONS
241-
}
242-
243-
+ // This is another call site (might not be called by our use case).
244-
auto fut = DecodeRowGroups(/*self=*/nullptr, row_groups, column_indices,
245-
/*cpu_executor=*/nullptr);
246-
ARROW_ASSIGN_OR_RAISE(*out, fut.MoveResult());
247-
@@ -1249,6 +1252,7 @@ Future<std::shared_ptr<Table>> FileReaderImpl::DecodeRowGroups(
248-
std::shared_ptr<ColumnReaderImpl> reader)
249-
-> ::arrow::Result<std::shared_ptr<::arrow::ChunkedArray>> {
250-
std::shared_ptr<::arrow::ChunkedArray> column;
251-
+ // This is the most likely place for invocation.
252-
RETURN_NOT_OK(ReadColumn(static_cast<int>(i), row_groups, reader.get(), &column));
253-
return column;
254-
};

0 commit comments

Comments
 (0)