Skip to content

Commit f182795

Browse files
committed
Rename macro GetMysql2Result to GET_RESULT
Switch from DATA_PTR to Data_Get_Struct. This protects against calling methods that need the result wrapper on hand-built Mysql2::Result objects. They will raise a TypeError instead of segfaulting.
1 parent 3473aee commit f182795

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

ext/mysql2/result.c

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ static rb_encoding *binaryEncoding;
5050
#define MYSQL2_MIN_TIME 62171150401ULL
5151
#endif
5252

53+
#define GET_RESULT(obj) \
54+
mysql2_result_wrapper *wrapper; \
55+
Data_Get_Struct(self, mysql2_result_wrapper, wrapper);
56+
5357
static VALUE cMysql2Result;
5458
static VALUE cBigDecimal, cDate, cDateTime;
5559
static VALUE opt_decimal_zero, opt_float_zero, opt_time_year, opt_time_month, opt_utc_offset;
@@ -103,9 +107,8 @@ static void *nogvl_fetch_row(void *ptr) {
103107
}
104108

105109
static VALUE rb_mysql_result_fetch_field(VALUE self, unsigned int idx, short int symbolize_keys) {
106-
mysql2_result_wrapper * wrapper;
107110
VALUE rb_field;
108-
GetMysql2Result(self, wrapper);
111+
GET_RESULT(self);
109112

110113
if (wrapper->fields == Qnil) {
111114
wrapper->numberOfFields = mysql_num_fields(wrapper->result);
@@ -193,7 +196,6 @@ static unsigned int msec_char_to_uint(char *msec_char, size_t len)
193196

194197
static VALUE rb_mysql_result_fetch_row(VALUE self, ID db_timezone, ID app_timezone, int symbolizeKeys, int asArray, int castBool, int cast, MYSQL_FIELD * fields) {
195198
VALUE rowVal;
196-
mysql2_result_wrapper * wrapper;
197199
MYSQL_ROW row;
198200
unsigned int i = 0;
199201
unsigned long * fieldLengths;
@@ -202,7 +204,7 @@ static VALUE rb_mysql_result_fetch_row(VALUE self, ID db_timezone, ID app_timezo
202204
rb_encoding *default_internal_enc;
203205
rb_encoding *conn_enc;
204206
#endif
205-
GetMysql2Result(self, wrapper);
207+
GET_RESULT(self);
206208

207209
#ifdef HAVE_RUBY_ENCODING_H
208210
default_internal_enc = rb_default_internal_encoding();
@@ -413,12 +415,11 @@ static VALUE rb_mysql_result_fetch_row(VALUE self, ID db_timezone, ID app_timezo
413415
}
414416

415417
static VALUE rb_mysql_result_fetch_fields(VALUE self) {
416-
mysql2_result_wrapper * wrapper;
417418
unsigned int i = 0;
418419
short int symbolizeKeys = 0;
419420
VALUE defaults;
420421

421-
GetMysql2Result(self, wrapper);
422+
GET_RESULT(self);
422423

423424
defaults = rb_iv_get(self, "@query_options");
424425
Check_Type(defaults, T_HASH);
@@ -443,13 +444,12 @@ static VALUE rb_mysql_result_fetch_fields(VALUE self) {
443444
static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) {
444445
VALUE defaults, opts, block;
445446
ID db_timezone, app_timezone, dbTz, appTz;
446-
mysql2_result_wrapper * wrapper;
447447
unsigned long i;
448448
const char * errstr;
449449
int symbolizeKeys, asArray, castBool, cacheRows, cast;
450450
MYSQL_FIELD * fields = NULL;
451451

452-
GetMysql2Result(self, wrapper);
452+
GET_RESULT(self);
453453

454454
defaults = rb_iv_get(self, "@query_options");
455455
Check_Type(defaults, T_HASH);
@@ -466,7 +466,7 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) {
466466
cast = RTEST(rb_hash_aref(opts, sym_cast));
467467

468468
if (wrapper->is_streaming && cacheRows) {
469-
rb_warn("cacheRows is ignored if streaming is true");
469+
rb_warn(":cache_rows is ignored if :stream is true");
470470
}
471471

472472
dbTz = rb_hash_aref(opts, sym_database_timezone);
@@ -577,9 +577,8 @@ static VALUE rb_mysql_result_each(int argc, VALUE * argv, VALUE self) {
577577
}
578578

579579
static VALUE rb_mysql_result_count(VALUE self) {
580-
mysql2_result_wrapper *wrapper;
580+
GET_RESULT(self);
581581

582-
GetMysql2Result(self, wrapper);
583582
if (wrapper->is_streaming) {
584583
/* This is an unsigned long per result.h */
585584
return ULONG2NUM(wrapper->numberOfRows);
@@ -598,6 +597,7 @@ static VALUE rb_mysql_result_count(VALUE self) {
598597
VALUE rb_mysql_result_to_obj(VALUE client, VALUE encoding, VALUE options, MYSQL_RES *r) {
599598
VALUE obj;
600599
mysql2_result_wrapper * wrapper;
600+
601601
obj = Data_Make_Struct(cMysql2Result, mysql2_result_wrapper, rb_mysql_result_mark, rb_mysql_result_free, wrapper);
602602
wrapper->numberOfFields = 0;
603603
wrapper->numberOfRows = 0;

ext/mysql2/result.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,4 @@ typedef struct {
1919
mysql_client_wrapper *client_wrapper;
2020
} mysql2_result_wrapper;
2121

22-
#define GetMysql2Result(obj, sval) (sval = (mysql2_result_wrapper*)DATA_PTR(obj));
23-
2422
#endif

spec/mysql2/result_spec.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,14 @@
66
@result = @client.query "SELECT 1"
77
end
88

9+
it "should raise a TypeError exception when it doesn't wrap a result set" do
10+
r = Mysql2::Result.new
11+
expect { r.count }.to raise_error(TypeError)
12+
expect { r.fields }.to raise_error(TypeError)
13+
expect { r.size }.to raise_error(TypeError)
14+
expect { r.each }.to raise_error(TypeError)
15+
end
16+
917
it "should have included Enumerable" do
1018
Mysql2::Result.ancestors.include?(Enumerable).should be_true
1119
end

0 commit comments

Comments
 (0)