diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 259d25ab7..5d6b7555f 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -38,10 +38,12 @@ static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args #define CONNECTED(wrapper) (wrapper->client->net.pvio != NULL && wrapper->client->net.fd != -1 && VIO_IS_CONNECTED(wrapper)) #endif +#define MYSQL_CLIENT_NOT_CONNECTED_STR "MySQL client is not connected" + #define REQUIRE_CONNECTED(wrapper) \ REQUIRE_INITIALIZED(wrapper) \ if (!CONNECTED(wrapper) && !wrapper->reconnect_enabled) { \ - rb_raise(cMysql2Error, "MySQL client is not connected"); \ + rb_raise(cMysql2Error, MYSQL_CLIENT_NOT_CONNECTED_STR); \ } #define REQUIRE_NOT_CONNECTED(wrapper) \ @@ -50,6 +52,17 @@ static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args rb_raise(cMysql2Error, "MySQL connection is already open"); \ } +/* + * assert that we've connected at least once by using + * `client->server_version`, which is a string that is initialized to the char* + * server name once the client has connected + */ +#define REQUIRE_CONNECTED_ONCE(wrapper) \ + REQUIRE_INITIALIZED(wrapper) \ + if (!wrapper->client->server_version) { \ + rb_raise(cMysql2Error, MYSQL_CLIENT_NOT_CONNECTED_STR); \ + } + /* * compatability with mysql-connector-c, where LIBMYSQL_VERSION is the correct * variable to use, but MYSQL_SERVER_VERSION gives the correct numbers when @@ -819,7 +832,7 @@ static VALUE rb_mysql_client_real_escape(VALUE self, VALUE str) { rb_encoding *conn_enc; GET_CLIENT(self); - REQUIRE_CONNECTED(wrapper); + REQUIRE_CONNECTED_ONCE(wrapper); Check_Type(str, T_STRING); default_internal_enc = rb_default_internal_encoding(); conn_enc = rb_to_encoding(wrapper->encoding); diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index e6a6cc6b2..cd31933a6 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -887,11 +887,11 @@ def run_gc end.not_to raise_error end - it "should require an open connection" do + it "should not require an open connection" do @client.close expect do @client.escape "" - end.to raise_error(Mysql2::Error) + end.not_to raise_error end context 'when mysql encoding is not utf8' do