Skip to content

Commit 1c0e06c

Browse files
committed
Track the Statement refcount before freeing it
1 parent 9c76acf commit 1c0e06c

File tree

5 files changed

+34
-21
lines changed

5 files changed

+34
-21
lines changed

ext/mysql2/mysql2_ext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ typedef unsigned int uint;
3939
#endif
4040

4141
#include <client.h>
42-
#include <result.h>
4342
#include <statement.h>
43+
#include <result.h>
4444
#include <infile.h>
4545

4646
#endif

ext/mysql2/result.c

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ static void rb_mysql_result_free_result(mysql2_result_wrapper * wrapper) {
9292
if (!wrapper) return;
9393

9494
if (wrapper->resultFreed != 1) {
95-
if (wrapper->stmt) {
96-
mysql_stmt_free_result(wrapper->stmt);
95+
if (wrapper->stmt_wrapper) {
96+
mysql_stmt_free_result(wrapper->stmt_wrapper->stmt);
9797

9898
if (wrapper->result_buffers) {
9999
unsigned int i;
@@ -125,6 +125,10 @@ static void rb_mysql_result_free(void *ptr) {
125125
decr_mysql2_client(wrapper->client_wrapper);
126126
}
127127

128+
if (wrapper->statement != Qnil) {
129+
decr_mysql2_stmt(wrapper->stmt_wrapper);
130+
}
131+
128132
xfree(wrapper);
129133
}
130134

@@ -341,23 +345,23 @@ static VALUE rb_mysql_result_fetch_row_stmt(VALUE self, MYSQL_FIELD * fields, co
341345
rb_mysql_result_alloc_result_buffers(self, fields);
342346
}
343347

344-
if(mysql_stmt_bind_result(wrapper->stmt, wrapper->result_buffers)) {
345-
rb_raise_mysql2_stmt_error2(wrapper->stmt
348+
if (mysql_stmt_bind_result(wrapper->stmt_wrapper->stmt, wrapper->result_buffers)) {
349+
rb_raise_mysql2_stmt_error2(wrapper->stmt_wrapper->stmt
346350
#ifdef HAVE_RUBY_ENCODING_H
347351
, conn_enc
348352
#endif
349353
);
350354
}
351355

352356
{
353-
switch((uintptr_t)rb_thread_call_without_gvl(nogvl_stmt_fetch, wrapper->stmt, RUBY_UBF_IO, 0)) {
357+
switch((uintptr_t)rb_thread_call_without_gvl(nogvl_stmt_fetch, wrapper->stmt_wrapper->stmt, RUBY_UBF_IO, 0)) {
354358
case 0:
355359
/* success */
356360
break;
357361

358362
case 1:
359363
/* error */
360-
rb_raise_mysql2_stmt_error2(wrapper->stmt
364+
rb_raise_mysql2_stmt_error2(wrapper->stmt_wrapper->stmt
361365
#ifdef HAVE_RUBY_ENCODING_H
362366
, conn_enc
363367
#endif
@@ -873,11 +877,11 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) {
873877
rb_warn(":cache_rows is ignored if :stream is true");
874878
}
875879

876-
if (wrapper->stmt && !cacheRows && !wrapper->is_streaming) {
880+
if (wrapper->stmt_wrapper && !cacheRows && !wrapper->is_streaming) {
877881
rb_warn(":cache_rows is forced for prepared statements (if not streaming)");
878882
}
879883

880-
if (wrapper->stmt && !cast) {
884+
if (wrapper->stmt_wrapper && !cast) {
881885
rb_warn(":cast is forced for prepared statements");
882886
}
883887

@@ -903,7 +907,7 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) {
903907
}
904908

905909
if (wrapper->lastRowProcessed == 0 && !wrapper->is_streaming) {
906-
wrapper->numberOfRows = wrapper->stmt ? mysql_stmt_num_rows(wrapper->stmt) : mysql_num_rows(wrapper->result);
910+
wrapper->numberOfRows = wrapper->stmt_wrapper ? mysql_stmt_num_rows(wrapper->stmt_wrapper->stmt) : mysql_num_rows(wrapper->result);
907911
if (wrapper->numberOfRows == 0) {
908912
wrapper->rows = rb_ary_new();
909913
return wrapper->rows;
@@ -921,7 +925,7 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) {
921925
args.app_timezone = app_timezone;
922926
args.block_given = block;
923927

924-
if (wrapper->stmt) {
928+
if (wrapper->stmt_wrapper) {
925929
fetch_row_func = rb_mysql_result_fetch_row_stmt;
926930
} else {
927931
fetch_row_func = rb_mysql_result_fetch_row;
@@ -943,8 +947,8 @@ static VALUE rb_mysql_result_count(VALUE self) {
943947
return LONG2NUM(RARRAY_LEN(wrapper->rows));
944948
} else {
945949
/* MySQL returns an unsigned 64-bit long here */
946-
if(wrapper->stmt) {
947-
return ULL2NUM(mysql_stmt_num_rows(wrapper->stmt));
950+
if (wrapper->stmt_wrapper) {
951+
return ULL2NUM(mysql_stmt_num_rows(wrapper->stmt_wrapper->stmt));
948952
} else {
949953
return ULL2NUM(mysql_num_rows(wrapper->result));
950954
}
@@ -977,9 +981,10 @@ VALUE rb_mysql_result_to_obj(VALUE client, VALUE encoding, VALUE options, MYSQL_
977981
/* Keep a handle to the Statement to ensure it doesn't get garbage collected first */
978982
wrapper->statement = statement;
979983
if (statement != Qnil) {
980-
mysql_stmt_wrapper *stmt_wrapper = DATA_PTR(statement);
981-
wrapper->stmt = stmt_wrapper->stmt;
982-
stmt_wrapper->refcount++;
984+
wrapper->stmt_wrapper = DATA_PTR(statement);
985+
wrapper->stmt_wrapper->refcount++;
986+
} else {
987+
wrapper->stmt_wrapper = NULL;
983988
}
984989

985990
rb_obj_call_init(obj, 0, NULL);

ext/mysql2/result.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ typedef struct {
1717
char streamingComplete;
1818
char resultFreed;
1919
MYSQL_RES *result;
20-
MYSQL_STMT *stmt;
20+
mysql_stmt_wrapper *stmt_wrapper;
2121
mysql_client_wrapper *client_wrapper;
2222
/* statement result bind buffers */
2323
MYSQL_BIND *result_buffers;

ext/mysql2/statement.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,16 @@ static void rb_mysql_stmt_mark(void * ptr) {
1919

2020
static void rb_mysql_stmt_free(void * ptr) {
2121
mysql_stmt_wrapper* stmt_wrapper = (mysql_stmt_wrapper *)ptr;
22+
decr_mysql2_stmt(stmt_wrapper);
23+
}
2224

23-
mysql_stmt_close(stmt_wrapper->stmt);
25+
void decr_mysql2_stmt(mysql_stmt_wrapper *stmt_wrapper) {
26+
stmt_wrapper->refcount--;
2427

25-
xfree(ptr);
28+
if (stmt_wrapper->refcount == 0) {
29+
mysql_stmt_close(stmt_wrapper->stmt);
30+
xfree(stmt_wrapper);
31+
}
2632
}
2733

2834
VALUE rb_raise_mysql2_stmt_error2(MYSQL_STMT *stmt
@@ -103,6 +109,7 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) {
103109
rb_stmt = Data_Make_Struct(cMysql2Statement, mysql_stmt_wrapper, rb_mysql_stmt_mark, rb_mysql_stmt_free, stmt_wrapper);
104110
{
105111
stmt_wrapper->client = rb_client;
112+
stmt_wrapper->refcount = 0;
106113
stmt_wrapper->stmt = NULL;
107114
}
108115

ext/mysql2/statement.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33

44
extern VALUE cMysql2Statement;
55

6-
void init_mysql2_statement();
7-
86
typedef struct {
97
VALUE client;
108
MYSQL_STMT *stmt;
119
int refcount;
1210
} mysql_stmt_wrapper;
1311

12+
void init_mysql2_statement();
13+
void decr_mysql2_stmt(mysql_stmt_wrapper *stmt_wrapper);
14+
1415
VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql);
1516
VALUE rb_raise_mysql2_stmt_error2(MYSQL_STMT *stmt
1617
#ifdef HAVE_RUBY_ENCODING_H

0 commit comments

Comments
 (0)