diff --git a/client.cpp b/client.cpp index a088d47..dece8ca 100755 --- a/client.cpp +++ b/client.cpp @@ -40,10 +40,9 @@ #include #endif -#ifdef HAVE_ASSERT_H -#include -#endif +#include +#include #include #include #include @@ -59,7 +58,14 @@ bool client::setup_client(benchmark_config *config, abstract_protocol *protocol, unsigned long long total_num_of_clients = config->clients*config->threads; // create main connection - shard_connection* conn = new shard_connection(m_connections.size(), this, m_config, m_event_base, protocol); + unsigned int thread_id = 0; // TODO: set actual thread id if available + unsigned int client_index = m_connections.size(); + unsigned int num_clients_per_thread = config->clients; + unsigned int conn_id = thread_id * num_clients_per_thread + client_index; + shard_connection* conn = new shard_connection( + client_index, this, m_config, m_event_base, protocol, + conn_id + ); m_connections.push_back(conn); m_obj_gen = objgen->clone(); @@ -99,7 +105,7 @@ bool client::setup_client(benchmark_config *config, abstract_protocol *protocol, return true; } -client::client(client_group* group) : +client::client(client_group* group, unsigned int conn_id) : m_event_base(NULL), m_initialized(false), m_end_set(false), m_config(NULL), m_obj_gen(NULL), m_stats(group->get_config()), m_reqs_processed(0), m_reqs_generated(0), m_set_ratio_count(0), m_get_ratio_count(0), @@ -108,16 +114,21 @@ client::client(client_group* group) : { m_event_base = group->get_event_base(); + // Initialize conn_id string and value with prefix + m_conn_id_str = "user" + std::to_string(conn_id); + m_conn_id_value = m_conn_id_str.c_str(); + m_conn_id_value_len = m_conn_id_str.length(); + if (!setup_client(group->get_config(), group->get_protocol(), group->get_obj_gen())) { return; } - benchmark_debug_log("new client %p successfully set up.\n", this); + benchmark_debug_log("new client %p successfully set up with conn_id: %s.\n", this, m_conn_id_value); m_initialized = true; } client::client(struct event_base *event_base, benchmark_config *config, - abstract_protocol *protocol, object_generator *obj_gen) : + abstract_protocol *protocol, object_generator *obj_gen, unsigned int conn_id) : m_event_base(NULL), m_initialized(false), m_end_set(false), m_config(NULL), m_obj_gen(NULL), m_stats(config), m_reqs_processed(0), m_reqs_generated(0), m_set_ratio_count(0), m_get_ratio_count(0), @@ -126,11 +137,16 @@ client::client(struct event_base *event_base, benchmark_config *config, { m_event_base = event_base; + // Initialize conn_id string and value + m_conn_id_str = std::to_string(conn_id); + m_conn_id_value = m_conn_id_str.c_str(); + m_conn_id_value_len = m_conn_id_str.length(); + if (!setup_client(config, protocol, obj_gen)) { return; } - benchmark_debug_log("new client %p successfully set up.\n", this); + benchmark_debug_log("new client %p successfully set up with conn_id: %s.\n", this, m_conn_id_value); m_initialized = true; } @@ -273,7 +289,11 @@ bool client::create_arbitrary_request(unsigned int command_index, struct timeval const arbitrary_command& cmd = get_arbitrary_command(command_index); - benchmark_debug_log("%s: %s:\n", m_connections[conn_id]->get_readable_id(), cmd.command.c_str()); + benchmark_debug_log("%s: %s", m_connections[conn_id]->get_readable_id(), cmd.command.c_str()); + + // Build final command string for debug output + std::string final_command = cmd.command; + bool has_substitutions = false; for (unsigned int i = 0; i < cmd.command_args.size(); i++) { const command_arg* arg = &cmd.command_args[i]; @@ -293,9 +313,32 @@ bool client::create_arbitrary_request(unsigned int command_index, struct timeval assert(value_len > 0); cmd_size += m_connections[conn_id]->send_arbitrary_command(arg, value, value_len); + } else if (arg->type == conn_id_type) { + // Replace __conn_id__ placeholder with actual connection ID + std::string substituted_arg = arg->data; + size_t pos = substituted_arg.find(CONN_PLACEHOLDER); + if (pos != std::string::npos) { + substituted_arg.replace(pos, strlen(CONN_PLACEHOLDER), m_conn_id_value); + has_substitutions = true; + } + + cmd_size += m_connections[conn_id]->send_arbitrary_command(arg, substituted_arg.c_str(), substituted_arg.length()); + + // Replace placeholder in final command string for debug output + pos = final_command.find(CONN_PLACEHOLDER); + if (pos != std::string::npos) { + final_command.replace(pos, strlen(CONN_PLACEHOLDER), m_conn_id_value); + } } } + // Show final command if substitutions were made + if (has_substitutions) { + benchmark_debug_log(" -> %s\n", final_command.c_str()); + } else { + benchmark_debug_log("\n"); + } + m_connections[conn_id]->send_arbitrary_command_end(command_index, ×tamp, cmd_size); return true; } @@ -581,8 +624,8 @@ bool verify_client::finished(void) /////////////////////////////////////////////////////////////////////////// -client_group::client_group(benchmark_config* config, abstract_protocol *protocol, object_generator* obj_gen) : - m_base(NULL), m_config(config), m_protocol(protocol), m_obj_gen(obj_gen) +client_group::client_group(benchmark_config* config, abstract_protocol *protocol, object_generator* obj_gen, unsigned int thread_id) : + m_base(NULL), m_config(config), m_protocol(protocol), m_obj_gen(obj_gen), m_thread_id(thread_id) { m_base = event_base_new(); assert(m_base != NULL); @@ -608,11 +651,12 @@ int client_group::create_clients(int num) { for (int i = 0; i < num; i++) { client* c; + unsigned int conn_id = m_thread_id * num + i + 1; if (m_config->cluster_mode) - c = new cluster_client(this); + c = new cluster_client(this, conn_id); else - c = new client(this); + c = new client(this, conn_id); assert(c != NULL); diff --git a/client.h b/client.h index 6f599a4..f704805 100755 --- a/client.h +++ b/client.h @@ -63,6 +63,9 @@ class client : public connections_manager { // test related benchmark_config* m_config; object_generator* m_obj_gen; + std::string m_conn_id_str; + const char* m_conn_id_value; + unsigned int m_conn_id_value_len; run_stats m_stats; unsigned long long m_reqs_processed; // requests processed (responses received) @@ -78,13 +81,14 @@ class client : public connections_manager { keylist *m_keylist; // used to construct multi commands public: - client(client_group* group); - client(struct event_base *event_base, benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen); + client(client_group* group, unsigned int conn_id = 0); + client(struct event_base *event_base, benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen, unsigned int conn_id = 0); virtual ~client(); bool setup_client(benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen); int prepare(void); bool initialized(void); run_stats* get_stats(void) { return &m_stats; } + const char* get_conn_id_value(void) { return m_conn_id_value; } virtual get_key_response get_key_for_conn(unsigned int command_index, unsigned int conn_id, unsigned long long* key_index); virtual bool create_arbitrary_request(unsigned int command_index, struct timeval& timestamp, unsigned int conn_id); @@ -203,8 +207,10 @@ class client_group { abstract_protocol* m_protocol; object_generator* m_obj_gen; std::vector m_clients; +protected: + unsigned int m_thread_id; public: - client_group(benchmark_config *cfg, abstract_protocol *protocol, object_generator* obj_gen); + client_group(benchmark_config *cfg, abstract_protocol *protocol, object_generator* obj_gen, unsigned int thread_id); ~client_group(); int create_clients(int count); diff --git a/cluster_client.cpp b/cluster_client.cpp index 10065bc..6584779 100644 --- a/cluster_client.cpp +++ b/cluster_client.cpp @@ -1,3 +1,4 @@ +#include /* * Copyright (C) 2011-2017 Redis Labs Ltd. * @@ -108,7 +109,7 @@ static uint32_t calc_hslot_crc16_cluster(const char *str, size_t length) /////////////////////////////////////////////////////////////////////////////////////////////////////// -cluster_client::cluster_client(client_group* group) : client(group) +cluster_client::cluster_client(client_group* group, unsigned int conn_id) : client(group, conn_id) { } @@ -159,9 +160,11 @@ void cluster_client::disconnect(void) } shard_connection* cluster_client::create_shard_connection(abstract_protocol* abs_protocol) { - shard_connection* sc = new shard_connection(m_connections.size(), this, - m_config, m_event_base, - abs_protocol); + unsigned int conn_id = m_connections.size(); + shard_connection* sc = new shard_connection( + conn_id, this, m_config, m_event_base, abs_protocol, + conn_id + ); assert(sc != NULL); m_connections.push_back(sc); diff --git a/cluster_client.h b/cluster_client.h index c792f67..1c244c2 100644 --- a/cluster_client.h +++ b/cluster_client.h @@ -43,7 +43,7 @@ class cluster_client : public client { request *request, protocol_response *response); public: - cluster_client(client_group* group); + cluster_client(client_group* group, unsigned int conn_id); virtual ~cluster_client(); virtual get_key_response get_key_for_conn(unsigned int command_index, unsigned int conn_id, unsigned long long* key_index); diff --git a/config_types.h b/config_types.h index 323d6a7..683b770 100644 --- a/config_types.h +++ b/config_types.h @@ -105,12 +105,14 @@ struct server_addr { #define KEY_PLACEHOLDER "__key__" #define DATA_PLACEHOLDER "__data__" +#define CONN_PLACEHOLDER "__conn_id__" enum command_arg_type { const_type = 0, key_type = 1, data_type = 2, - undefined_type = 3 + conn_id_type = 3, + undefined_type = 4 }; struct command_arg { diff --git a/memtier_benchmark.cpp b/memtier_benchmark.cpp index 0982354..3ea8f1d 100755 --- a/memtier_benchmark.cpp +++ b/memtier_benchmark.cpp @@ -1230,7 +1230,8 @@ struct cg_thread { m_protocol = protocol_factory(m_config->protocol); assert(m_protocol != NULL); - m_cg = new client_group(m_config, m_protocol, m_obj_gen); + // Pass thread_id to client_group + m_cg = new client_group(m_config, m_protocol, m_obj_gen, m_thread_id); } ~cg_thread() diff --git a/protocol.cpp b/protocol.cpp index 5bb14dc..c8f161c 100644 --- a/protocol.cpp +++ b/protocol.cpp @@ -1,3 +1,4 @@ +#include /* * Copyright (C) 2011-2017 Redis Labs Ltd. * @@ -175,7 +176,7 @@ class redis_protocol : public abstract_protocol { redis_protocol() : m_response_state(rs_initial), m_bulk_len(0), m_response_len(0), m_total_bulks_count(0), m_current_mbulk(NULL), m_resp3(false), m_attribute(false) { } virtual redis_protocol* clone(void) { return new redis_protocol(); } virtual int select_db(int db); - virtual int authenticate(const char *credentials); + virtual int authenticate(const char *user, const char *credentials); virtual int configure_protocol(enum PROTOCOL_TYPE type); virtual int write_command_cluster_slots(); virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset); @@ -206,7 +207,7 @@ int redis_protocol::select_db(int db) return size; } -int redis_protocol::authenticate(const char *credentials) +int redis_protocol::authenticate(const char *user, const char *credentials) { int size = 0; assert(credentials != NULL); @@ -219,7 +220,6 @@ int redis_protocol::authenticate(const char *credentials) * contains a colon. */ - const char *user = NULL; const char *password; if (credentials[0] == ':') { @@ -229,12 +229,11 @@ int redis_protocol::authenticate(const char *credentials) if (!password) { password = credentials; } else { - user = credentials; password++; } } - if (!user) { + if (!user || strlen(user) == 0) { size = evbuffer_add_printf(m_write_buf, "*2\r\n" "$4\r\n" @@ -243,17 +242,16 @@ int redis_protocol::authenticate(const char *credentials) "%s\r\n", strlen(password), password); } else { - size_t user_len = password - user - 1; + size_t user_len = strlen(user); size = evbuffer_add_printf(m_write_buf, "*3\r\n" "$4\r\n" "AUTH\r\n" "$%zu\r\n" - "%.*s\r\n" + "%s\r\n" "$%zu\r\n" "%s\r\n", user_len, - (int) user_len, user, strlen(password), password); @@ -723,8 +721,10 @@ bool redis_protocol::format_arbitrary_command(arbitrary_command &cmd) { benchmark_error_log("error: data placeholder can't combined with other data\n"); return false; } - current_arg->type = data_type; + } else if (current_arg->data.find(CONN_PLACEHOLDER) != std::string::npos) { + // Allow conn_id placeholder to be combined with other text + current_arg->type = conn_id_type; } // we expect that first arg is the COMMAND name @@ -761,7 +761,7 @@ class memcache_text_protocol : public abstract_protocol { memcache_text_protocol() : m_response_state(rs_initial), m_value_len(0), m_response_len(0) { } virtual memcache_text_protocol* clone(void) { return new memcache_text_protocol(); } virtual int select_db(int db); - virtual int authenticate(const char *credentials); + virtual int authenticate(const char *user, const char *credentials); virtual int configure_protocol(enum PROTOCOL_TYPE type); virtual int write_command_cluster_slots(); virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset); @@ -782,7 +782,7 @@ int memcache_text_protocol::select_db(int db) assert(0); } -int memcache_text_protocol::authenticate(const char *credentials) +int memcache_text_protocol::authenticate(const char *user, const char *credentials) { assert(0); } @@ -983,7 +983,7 @@ class memcache_binary_protocol : public abstract_protocol { memcache_binary_protocol() : m_response_state(rs_initial), m_response_len(0) { } virtual memcache_binary_protocol* clone(void) { return new memcache_binary_protocol(); } virtual int select_db(int db); - virtual int authenticate(const char *credentials); + virtual int authenticate(const char *user, const char *credentials); virtual int configure_protocol(enum PROTOCOL_TYPE type); virtual int write_command_cluster_slots(); virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset); @@ -1003,14 +1003,13 @@ int memcache_binary_protocol::select_db(int db) assert(0); } -int memcache_binary_protocol::authenticate(const char *credentials) +int memcache_binary_protocol::authenticate(const char *user, const char *credentials) { protocol_binary_request_no_extras req; char nullbyte = '\0'; const char mechanism[] = "PLAIN"; int mechanism_len = sizeof(mechanism) - 1; const char *colon; - const char *user; int user_len; const char *passwd; int passwd_len; @@ -1019,8 +1018,8 @@ int memcache_binary_protocol::authenticate(const char *credentials) colon = strchr(credentials, ':'); assert(colon != NULL); - user = credentials; - user_len = colon - user; + // Use the user parameter instead of extracting from credentials + user_len = strlen(user); passwd = colon + 1; passwd_len = strlen(passwd); diff --git a/protocol.h b/protocol.h index 05b53bf..faab80d 100644 --- a/protocol.h +++ b/protocol.h @@ -183,7 +183,7 @@ class abstract_protocol { void set_keep_value(bool flag); virtual int select_db(int db) = 0; - virtual int authenticate(const char *credentials) = 0; + virtual int authenticate(const char *user, const char *credentials) = 0; virtual int configure_protocol(enum PROTOCOL_TYPE type) = 0; virtual int write_command_cluster_slots() = 0; virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset) = 0; diff --git a/shard_connection.cpp b/shard_connection.cpp index e873308..32c5323 100644 --- a/shard_connection.cpp +++ b/shard_connection.cpp @@ -50,6 +50,8 @@ #include "connections_manager.h" #include "event2/bufferevent.h" +#include "client.h" + #ifdef USE_TLS #include #include @@ -127,14 +129,20 @@ verify_request::~verify_request(void) } shard_connection::shard_connection(unsigned int id, connections_manager* conns_man, benchmark_config* config, - struct event_base* event_base, abstract_protocol* abs_protocol) : - m_address(NULL), m_port(NULL), m_unix_sockaddr(NULL), + struct event_base* event_base, abstract_protocol* abs_protocol, + unsigned int conn_id) + : m_address(NULL), m_port(NULL), m_unix_sockaddr(NULL), m_bev(NULL), m_event_timer(NULL), m_request_per_cur_interval(0), m_pending_resp(0), m_connection_state(conn_disconnected), - m_hello(setup_done), m_authentication(setup_done), m_db_selection(setup_done), m_cluster_slots(setup_done) { + m_hello(setup_done), m_authentication(setup_done), m_db_selection(setup_done), m_cluster_slots(setup_done) +{ m_id = id; m_conns_manager = conns_man; m_config = config; m_event_base = event_base; + m_conn_id = conn_id; + + // Initialize connection ID string for fallback + m_conn_id_string = "user" + std::to_string(m_conn_id + 1); if (m_config->unix_socket) { m_unix_sockaddr = (struct sockaddr_un *) malloc(sizeof(struct sockaddr_un)); @@ -376,10 +384,10 @@ bool shard_connection::is_conn_setup_done() { m_hello == setup_done; } -void shard_connection::send_conn_setup_commands(struct timeval timestamp) { +void shard_connection::send_conn_setup_commands(struct timeval timestamp, const char* conn_id_string) { if (m_authentication == setup_none) { - benchmark_debug_log("sending authentication command.\n"); - m_protocol->authenticate(m_config->authenticate); + benchmark_debug_log("sending authentication command user: %s pass %s.\n",conn_id_string,m_config->authenticate); + m_protocol->authenticate(conn_id_string, m_config->authenticate); push_req(new request(rt_auth, 0, ×tamp, 0)); m_authentication = setup_sent; } @@ -521,7 +529,15 @@ void shard_connection::fill_pipeline(void) while (!m_conns_manager->finished() && m_pipeline->size() < m_config->pipeline) { if (!is_conn_setup_done()) { - send_conn_setup_commands(now); + // Get username from connections_manager (client object) + const char* conn_id_string = nullptr; + if (m_conns_manager) { + conn_id_string = static_cast(m_conns_manager)->get_conn_id_value(); + } else { + // fallback: use stored connection ID string + conn_id_string = m_conn_id_string.c_str(); + } + send_conn_setup_commands(now, conn_id_string); return; } diff --git a/shard_connection.h b/shard_connection.h index 12fbaae..24e0b1c 100644 --- a/shard_connection.h +++ b/shard_connection.h @@ -82,7 +82,8 @@ class shard_connection { public: shard_connection(unsigned int id, connections_manager* conn_man, benchmark_config* config, - struct event_base* event_base, abstract_protocol* abs_protocol); + struct event_base* event_base, abstract_protocol* abs_protocol, + unsigned int conn_id); ~shard_connection(); void set_address_port(const char* address, const char* port); @@ -133,12 +134,15 @@ class shard_connection { } private: + unsigned int m_conn_id; + std::string m_conn_id_string; + void setup_event(int sockfd); int setup_socket(struct connect_info* addr); void set_readable_id(); bool is_conn_setup_done(); - void send_conn_setup_commands(struct timeval timestamp); + void send_conn_setup_commands(struct timeval timestamp, const char* conn_id_string); request* pop_req(); void push_req(request* req);