Skip to content

Commit 9f00684

Browse files
committed
Merge pull request #592 from tamird/forbid-interrupting-timeouts
Use `Thread.handle_interrupt` to protect `query`
2 parents 435b550 + 37da305 commit 9f00684

File tree

3 files changed

+44
-64
lines changed

3 files changed

+44
-64
lines changed

ext/mysql2/client.c

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -643,39 +643,26 @@ static VALUE rb_mysql_client_abandon_results(VALUE self) {
643643
* Query the database with +sql+, with optional +options+. For the possible
644644
* options, see @@default_query_options on the Mysql2::Client class.
645645
*/
646-
static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
646+
static VALUE rb_query(VALUE self, VALUE sql, VALUE current) {
647647
#ifndef _WIN32
648648
struct async_query_args async_args;
649649
#endif
650650
struct nogvl_send_query_args args;
651-
int async = 0;
652-
VALUE opts, current;
653-
#ifdef HAVE_RUBY_ENCODING_H
654-
rb_encoding *conn_enc;
655-
#endif
656651
GET_CLIENT(self);
657652

658653
REQUIRE_CONNECTED(wrapper);
659654
args.mysql = wrapper->client;
660655

661-
current = rb_hash_dup(rb_iv_get(self, "@query_options"));
662656
RB_GC_GUARD(current);
663657
Check_Type(current, T_HASH);
664658
rb_iv_set(self, "@current_query_options", current);
665659

666-
if (rb_scan_args(argc, argv, "11", &args.sql, &opts) == 2) {
667-
rb_funcall(current, intern_merge_bang, 1, opts);
668-
669-
if (rb_hash_aref(current, sym_async) == Qtrue) {
670-
async = 1;
671-
}
672-
}
673-
674-
Check_Type(args.sql, T_STRING);
660+
Check_Type(sql, T_STRING);
675661
#ifdef HAVE_RUBY_ENCODING_H
676-
conn_enc = rb_to_encoding(wrapper->encoding);
677662
/* ensure the string is in the encoding the connection is expecting */
678-
args.sql = rb_str_export_to_enc(args.sql, conn_enc);
663+
args.sql = rb_str_export_to_enc(sql, rb_to_encoding(wrapper->encoding));
664+
#else
665+
args.sql = sql;
679666
#endif
680667
args.sql_ptr = StringValuePtr(args.sql);
681668
args.sql_len = RSTRING_LEN(args.sql);
@@ -686,15 +673,15 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) {
686673
#ifndef _WIN32
687674
rb_rescue2(do_send_query, (VALUE)&args, disconnect_and_raise, self, rb_eException, (VALUE)0);
688675

689-
if (!async) {
676+
if (rb_hash_aref(current, sym_async) == Qtrue) {
677+
return Qnil;
678+
} else {
690679
async_args.fd = wrapper->client->net.fd;
691680
async_args.self = self;
692681

693682
rb_rescue2(do_query, (VALUE)&async_args, disconnect_and_raise, self, rb_eException, (VALUE)0);
694683

695684
return rb_mysql_client_async_result(self);
696-
} else {
697-
return Qnil;
698685
}
699686
#else
700687
do_send_query(&args);
@@ -1262,7 +1249,6 @@ void init_mysql2_client() {
12621249
rb_define_singleton_method(cMysql2Client, "escape", rb_mysql_client_escape, 1);
12631250

12641251
rb_define_method(cMysql2Client, "close", rb_mysql_client_close, 0);
1265-
rb_define_method(cMysql2Client, "query", rb_mysql_client_query, -1);
12661252
rb_define_method(cMysql2Client, "abandon_results!", rb_mysql_client_abandon_results, 0);
12671253
rb_define_method(cMysql2Client, "escape", rb_mysql_client_real_escape, 1);
12681254
rb_define_method(cMysql2Client, "info", rb_mysql_client_info, 0);
@@ -1297,6 +1283,7 @@ void init_mysql2_client() {
12971283
rb_define_private_method(cMysql2Client, "ssl_set", set_ssl_options, 5);
12981284
rb_define_private_method(cMysql2Client, "initialize_ext", initialize_ext, 0);
12991285
rb_define_private_method(cMysql2Client, "connect", rb_connect, 7);
1286+
rb_define_private_method(cMysql2Client, "_query", rb_query, 2);
13001287

13011288
sym_id = ID2SYM(rb_intern("id"));
13021289
sym_version = ID2SYM(rb_intern("version"));

lib/mysql2/client.rb

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ def self.default_query_options
7474
@@default_query_options
7575
end
7676

77+
if Thread.respond_to?(:handle_interrupt)
78+
def query(sql, options = {})
79+
Thread.handle_interrupt(Timeout::ExitException => :never) do
80+
_query(sql, @query_options.merge(options))
81+
end
82+
end
83+
else
84+
def query(sql, options = {})
85+
_query(sql, @query_options.merge(options))
86+
end
87+
end
88+
7789
def query_info
7890
info = query_info_string
7991
return {} unless info

spec/mysql2/client_spec.rb

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -461,59 +461,40 @@ def connect *args
461461
}.should raise_error(Mysql2::Error)
462462
end
463463

464-
it "should close the connection when an exception is raised" do
465-
begin
466-
Timeout.timeout(1, Timeout::Error) do
467-
@client.query("SELECT sleep(2)")
468-
end
469-
rescue Timeout::Error
470-
end
471464

472-
lambda {
473-
@client.query("SELECT 1")
474-
}.should raise_error(Mysql2::Error, 'closed MySQL connection')
465+
it 'should be impervious to connection-corrupting timeouts ' do
466+
pending('`Thread.handle_interrupt` is not defined') unless Thread.respond_to?(:handle_interrupt)
467+
# attempt to break the connection
468+
expect { Timeout.timeout(0.1) { @client.query('SELECT SLEEP(1)') } }.to raise_error(Timeout::Error)
469+
470+
# expect the connection to not be broken
471+
expect { @client.query('SELECT 1') }.to_not raise_error
475472
end
476473

477-
it "should handle Timeouts without leaving the connection hanging if reconnect is true" do
478-
client = Mysql2::Client.new(DatabaseCredentials['root'].merge(:reconnect => true))
479-
begin
480-
Timeout.timeout(1, Timeout::Error) do
481-
client.query("SELECT sleep(2)")
482-
end
483-
rescue Timeout::Error
474+
context 'when a non-standard exception class is raised' do
475+
it "should close the connection when an exception is raised" do
476+
expect { Timeout.timeout(0.1, ArgumentError) { @client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError)
477+
expect { @client.query('SELECT 1') }.to raise_error(Mysql2::Error, 'closed MySQL connection')
484478
end
485479

486-
lambda {
487-
client.query("SELECT 1")
488-
}.should_not raise_error(Mysql2::Error)
489-
end
480+
it "should handle Timeouts without leaving the connection hanging if reconnect is true" do
481+
client = Mysql2::Client.new(DatabaseCredentials['root'].merge(:reconnect => true))
490482

491-
it "should handle Timeouts without leaving the connection hanging if reconnect is set to true after construction true" do
492-
client = Mysql2::Client.new(DatabaseCredentials['root'])
493-
begin
494-
Timeout.timeout(1, Timeout::Error) do
495-
client.query("SELECT sleep(2)")
496-
end
497-
rescue Timeout::Error
483+
expect { Timeout.timeout(0.1, ArgumentError) { client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError)
484+
expect { client.query('SELECT 1') }.to_not raise_error
498485
end
499486

500-
lambda {
501-
client.query("SELECT 1")
502-
}.should raise_error(Mysql2::Error)
487+
it "should handle Timeouts without leaving the connection hanging if reconnect is set to true after construction true" do
488+
client = Mysql2::Client.new(DatabaseCredentials['root'])
503489

504-
client.reconnect = true
490+
expect { Timeout.timeout(0.1, ArgumentError) { client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError)
491+
expect { client.query('SELECT 1') }.to raise_error(Mysql2::Error)
505492

506-
begin
507-
Timeout.timeout(1, Timeout::Error) do
508-
client.query("SELECT sleep(2)")
509-
end
510-
rescue Timeout::Error
511-
end
512-
513-
lambda {
514-
client.query("SELECT 1")
515-
}.should_not raise_error(Mysql2::Error)
493+
client.reconnect = true
516494

495+
expect { Timeout.timeout(0.1, ArgumentError) { client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError)
496+
expect { client.query('SELECT 1') }.to_not raise_error
497+
end
517498
end
518499

519500
it "threaded queries should be supported" do

0 commit comments

Comments
 (0)