3030#include < limits>
3131#include < memory>
3232#include < string>
33+ #include < string_view>
3334#include < utility>
3435#include < vector>
3536
@@ -219,55 +220,89 @@ void TupleReader::Release() {
219220 row_id_ = -1 ;
220221}
221222
223+ // Instead of directly exporting the TupleReader, which is tied to the
224+ // lifetime of the Statement, we export a weak_ptr reference instead. That
225+ // way if the user accidentally closes the Statement before the
226+ // ArrowArrayStream, we can avoid a crash.
227+ // See https://github.com/apache/arrow-adbc/issues/2629
228+ struct ExportedTupleReader {
229+ std::weak_ptr<TupleReader> self;
230+ };
231+
222232void TupleReader::ExportTo (struct ArrowArrayStream * stream) {
223233 stream->get_schema = &GetSchemaTrampoline;
224234 stream->get_next = &GetNextTrampoline;
225235 stream->get_last_error = &GetLastErrorTrampoline;
226236 stream->release = &ReleaseTrampoline;
227- stream->private_data = this ;
237+ stream->private_data = new ExportedTupleReader{ weak_from_this ()} ;
228238}
229239
230- const struct AdbcError * TupleReader::ErrorFromArrayStream (struct ArrowArrayStream * stream ,
240+ const struct AdbcError * TupleReader::ErrorFromArrayStream (struct ArrowArrayStream * self ,
231241 AdbcStatusCode* status) {
232- if (!stream ->private_data || stream ->release != &ReleaseTrampoline) {
242+ if (!self ->private_data || self ->release != &ReleaseTrampoline) {
233243 return nullptr ;
234244 }
235245
236- TupleReader* reader = static_cast <TupleReader*>(stream->private_data );
237- if (status) {
238- *status = reader->status_ ;
246+ auto * wrapper = static_cast <ExportedTupleReader*>(self->private_data );
247+ auto maybe_reader = wrapper->self .lock ();
248+ if (maybe_reader) {
249+ if (status) {
250+ *status = maybe_reader->status_ ;
251+ }
252+ return &maybe_reader->error_ ;
239253 }
240- return &reader-> error_ ;
254+ return nullptr ;
241255}
242256
243257int TupleReader::GetSchemaTrampoline (struct ArrowArrayStream * self,
244258 struct ArrowSchema * out) {
245259 if (!self || !self->private_data ) return EINVAL;
246260
247- TupleReader* reader = static_cast <TupleReader*>(self->private_data );
248- return reader->GetSchema (out);
261+ auto * wrapper = static_cast <ExportedTupleReader*>(self->private_data );
262+ auto maybe_reader = wrapper->self .lock ();
263+ if (maybe_reader) {
264+ return maybe_reader->GetSchema (out);
265+ }
266+ // statement was closed or reader was otherwise invalidated
267+ return EINVAL;
249268}
250269
251270int TupleReader::GetNextTrampoline (struct ArrowArrayStream * self,
252271 struct ArrowArray * out) {
253272 if (!self || !self->private_data ) return EINVAL;
254273
255- TupleReader* reader = static_cast <TupleReader*>(self->private_data );
256- return reader->GetNext (out);
274+ auto * wrapper = static_cast <ExportedTupleReader*>(self->private_data );
275+ auto maybe_reader = wrapper->self .lock ();
276+ if (maybe_reader) {
277+ return maybe_reader->GetNext (out);
278+ }
279+ // statement was closed or reader was otherwise invalidated
280+ return EINVAL;
257281}
258282
259283const char * TupleReader::GetLastErrorTrampoline (struct ArrowArrayStream * self) {
260284 if (!self || !self->private_data ) return nullptr ;
285+ constexpr std::string_view kReaderInvalidated =
286+ " [libpq] Reader invalidated (statement or reader was closed)" ;
261287
262- TupleReader* reader = static_cast <TupleReader*>(self->private_data );
263- return reader->last_error ();
288+ auto * wrapper = static_cast <ExportedTupleReader*>(self->private_data );
289+ auto maybe_reader = wrapper->self .lock ();
290+ if (maybe_reader) {
291+ return maybe_reader->last_error ();
292+ }
293+ // statement was closed or reader was otherwise invalidated
294+ return kReaderInvalidated .data ();
264295}
265296
266297void TupleReader::ReleaseTrampoline (struct ArrowArrayStream * self) {
267298 if (!self || !self->private_data ) return ;
268299
269- TupleReader* reader = static_cast <TupleReader*>(self->private_data );
270- reader->Release ();
300+ auto * wrapper = static_cast <ExportedTupleReader*>(self->private_data );
301+ auto maybe_reader = wrapper->self .lock ();
302+ if (maybe_reader) {
303+ maybe_reader->Release ();
304+ }
305+ delete wrapper;
271306 self->private_data = nullptr ;
272307 self->release = nullptr ;
273308}
@@ -281,7 +316,7 @@ AdbcStatusCode PostgresStatement::New(struct AdbcConnection* connection,
281316 connection_ =
282317 *reinterpret_cast <std::shared_ptr<PostgresConnection>*>(connection->private_data );
283318 type_resolver_ = connection_->type_resolver ();
284- reader_. conn_ = connection_-> conn ();
319+ ClearResult ();
285320 return ADBC_STATUS_OK;
286321}
287322
@@ -514,24 +549,24 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
514549 }
515550
516551 struct ArrowError na_error;
517- reader_. copy_reader_ = std::make_unique<PostgresCopyStreamReader>();
518- CHECK_NA (INTERNAL, reader_. copy_reader_ ->Init (root_type), error);
552+ reader_-> copy_reader_ = std::make_unique<PostgresCopyStreamReader>();
553+ CHECK_NA (INTERNAL, reader_-> copy_reader_ ->Init (root_type), error);
519554 CHECK_NA_DETAIL (INTERNAL,
520- reader_. copy_reader_ ->InferOutputSchema (
555+ reader_-> copy_reader_ ->InferOutputSchema (
521556 std::string (connection_->VendorName ()), &na_error),
522557 &na_error, error);
523558
524- CHECK_NA_DETAIL (INTERNAL, reader_. copy_reader_ ->InitFieldReaders (&na_error), &na_error,
559+ CHECK_NA_DETAIL (INTERNAL, reader_-> copy_reader_ ->InitFieldReaders (&na_error), &na_error,
525560 error);
526561
527562 // Execute the COPY query
528563 RAISE_STATUS (error, helper.ExecuteCopy ());
529564
530565 // We need the PQresult back for the reader
531- reader_. result_ = helper.ReleaseResult ();
566+ reader_-> result_ = helper.ReleaseResult ();
532567
533568 // Export to stream
534- reader_. ExportTo (stream);
569+ reader_-> ExportTo (stream);
535570 if (rows_affected) *rows_affected = -1 ;
536571 return ADBC_STATUS_OK;
537572}
@@ -674,7 +709,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
674709 break ;
675710 }
676711 } else if (std::strcmp (key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0 ) {
677- result = std::to_string (reader_. batch_size_hint_bytes_ );
712+ result = std::to_string (reader_-> batch_size_hint_bytes_ );
678713 } else if (std::strcmp (key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0 ) {
679714 if (UseCopy ()) {
680715 result = " true" ;
@@ -710,7 +745,7 @@ AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value,
710745 struct AdbcError * error) {
711746 std::string result;
712747 if (std::strcmp (key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0 ) {
713- *value = reader_. batch_size_hint_bytes_ ;
748+ *value = reader_-> batch_size_hint_bytes_ ;
714749 return ADBC_STATUS_OK;
715750 }
716751 SetError (error, " [libpq] Unknown statement option '%s'" , key);
@@ -799,7 +834,7 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
799834 return ADBC_STATUS_INVALID_ARGUMENT;
800835 }
801836
802- this ->reader_ . batch_size_hint_bytes_ = int_value;
837+ this ->batch_size_hint_bytes_ = this -> reader_ -> batch_size_hint_bytes_ = int_value;
803838 } else if (std::strcmp (key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0 ) {
804839 if (std::strcmp (value, ADBC_OPTION_VALUE_ENABLED) == 0 ) {
805840 use_copy_ = true ;
@@ -836,7 +871,7 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
836871 return ADBC_STATUS_INVALID_ARGUMENT;
837872 }
838873
839- this ->reader_ . batch_size_hint_bytes_ = value;
874+ this ->batch_size_hint_bytes_ = this -> reader_ -> batch_size_hint_bytes_ = value;
840875 return ADBC_STATUS_OK;
841876 }
842877 SetError (error, " [libpq] Unknown statement option '%s'" , key);
@@ -845,7 +880,9 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
845880
846881void PostgresStatement::ClearResult () {
847882 // TODO: we may want to synchronize here for safety
848- reader_.Release ();
883+ if (reader_) reader_->Release ();
884+ reader_ = std::make_shared<TupleReader>(connection_->conn ());
885+ reader_->batch_size_hint_bytes_ = batch_size_hint_bytes_;
849886}
850887
851888int PostgresStatement::UseCopy () {
0 commit comments