|
18 | 18 | # |
19 | 19 |
|
20 | 20 | $:.unshift File.dirname(__FILE__) + '/../lib' |
| 21 | +$:.unshift File.dirname(__FILE__) + '/../ext' |
21 | 22 | require 'thrift' |
| 23 | +require 'openssl' |
22 | 24 | $:.unshift File.dirname(__FILE__) + "/gen-rb" |
23 | 25 | require 'benchmark_service' |
24 | 26 |
|
25 | 27 | class Client |
26 | | - def initialize(host, port, clients_per_process, calls_per_client, log_exceptions) |
| 28 | + def initialize(host, port, clients_per_process, calls_per_client, log_exceptions, tls) |
27 | 29 | @host = host |
28 | 30 | @port = port |
29 | 31 | @clients_per_process = clients_per_process |
30 | 32 | @calls_per_client = calls_per_client |
31 | 33 | @log_exceptions = log_exceptions |
| 34 | + @tls = tls |
32 | 35 | end |
33 | 36 |
|
34 | 37 | def run |
35 | 38 | @clients_per_process.times do |
36 | | - socket = Thrift::Socket.new(@host, @port) |
| 39 | + socket = if @tls |
| 40 | + ssl_context = OpenSSL::SSL::SSLContext.new.tap do |ctx| |
| 41 | + ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER |
| 42 | + ctx.min_version = OpenSSL::SSL::TLS1_2_VERSION |
| 43 | + |
| 44 | + keys_dir = File.expand_path("../../../test/keys", __dir__) |
| 45 | + ctx.ca_file = File.join(keys_dir, "CA.pem") |
| 46 | + ctx.cert = OpenSSL::X509::Certificate.new(File.open(File.join(keys_dir, "client.crt"))) |
| 47 | + ctx.cert_store = OpenSSL::X509::Store.new |
| 48 | + ctx.cert_store.add_file(File.join(keys_dir, 'server.pem')) |
| 49 | + ctx.key = OpenSSL::PKey::RSA.new(File.open(File.join(keys_dir, "client.key"))) |
| 50 | + end |
| 51 | + |
| 52 | + Thrift::SSLSocket.new(@host, @port, nil, ssl_context) |
| 53 | + else |
| 54 | + Thrift::Socket.new(@host, @port) |
| 55 | + end |
37 | 56 | transport = Thrift::FramedTransport.new(socket) |
38 | 57 | protocol = Thrift::BinaryProtocol.new(transport) |
39 | 58 | client = ThriftBenchmark::BenchmarkService::Client.new(protocol) |
@@ -68,7 +87,8 @@ def print_exception(e) |
68 | 87 | end |
69 | 88 |
|
70 | 89 | log_exceptions = true if ARGV[0] == '-log-exceptions' and ARGV.shift |
| 90 | +tls = true if ARGV[0] == '-tls' and ARGV.shift |
71 | 91 |
|
72 | 92 | host, port, clients_per_process, calls_per_client = ARGV |
73 | 93 |
|
74 | | -Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions).run |
| 94 | +Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions, tls).run |
0 commit comments