Skip to content

Commit 5d19ace

Browse files
committed
Add TLS support to the benchmark script
1 parent b8f7e5b commit 5d19ace

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

lib/rb/benchmark/benchmark.rb

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
require 'rubygems'
2121
$:.unshift File.dirname(__FILE__) + '/../lib'
22+
$:.unshift File.dirname(__FILE__) + '/../ext'
2223
require 'thrift'
2324
require 'stringio'
2425

@@ -40,12 +41,13 @@ def initialize(opts)
4041
@interpreter = opts.fetch(:interpreter, "ruby")
4142
@host = opts.fetch(:host, ::HOST)
4243
@port = opts.fetch(:port, ::PORT)
44+
@tls = opts.fetch(:tls, false)
4345
end
4446

4547
def start
4648
return if @serverclass == Object
4749
args = (File.basename(@interpreter) == "jruby" ? "-J-server" : "")
48-
@pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{@host} #{@port} #{@serverclass.name}", "r+")
50+
@pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{"-tls" if @tls} #{@host} #{@port} #{@serverclass.name}", "r+")
4951
Marshal.load(@pipe) # wait until the server has started
5052
sleep 0.4 # give the server time to actually start spawning sockets
5153
end
@@ -75,6 +77,7 @@ def initialize(opts, server)
7577
@interpreter = opts.fetch(:interpreter, "ruby")
7678
@server = server
7779
@log_exceptions = opts.fetch(:log_exceptions, false)
80+
@tls = opts.fetch(:tls, false)
7881
end
7982

8083
def run
@@ -93,13 +96,15 @@ def run
9396
end
9497

9598
def spawn
96-
pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}")
99+
pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{"-tls" if @tls} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}")
97100
@pool << pipe
98101
end
99102

100103
def socket_class
101104
if @socket
102105
Thrift::UNIXSocket
106+
elsif @tls
107+
Thrift::SSLSocket
103108
else
104109
Thrift::Socket
105110
end
@@ -255,12 +260,14 @@ def resolve_const(const)
255260
args[:class] = resolve_const(ENV['THRIFT_SERVER']) || Thrift::NonblockingServer
256261
args[:host] = ENV['THRIFT_HOST'] || HOST
257262
args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i
263+
args[:tls] = ENV['THRIFT_TLS'] == 'true'
258264
server = Server.new(args)
259265
server.start
260266

261267
args = {}
262268
args[:host] = ENV['THRIFT_HOST'] || HOST
263269
args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i
270+
args[:tls] = ENV['THRIFT_TLS'] == 'true'
264271
args[:num_processes] = (ENV['THRIFT_NUM_PROCESSES'] || 40).to_i
265272
args[:clients_per_process] = (ENV['THRIFT_NUM_CLIENTS'] || 5).to_i
266273
args[:calls_per_client] = (ENV['THRIFT_NUM_CALLS'] || 50).to_i

lib/rb/benchmark/client.rb

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,41 @@
1818
#
1919

2020
$:.unshift File.dirname(__FILE__) + '/../lib'
21+
$:.unshift File.dirname(__FILE__) + '/../ext'
2122
require 'thrift'
23+
require 'openssl'
2224
$:.unshift File.dirname(__FILE__) + "/gen-rb"
2325
require 'benchmark_service'
2426

2527
class Client
26-
def initialize(host, port, clients_per_process, calls_per_client, log_exceptions)
28+
def initialize(host, port, clients_per_process, calls_per_client, log_exceptions, tls)
2729
@host = host
2830
@port = port
2931
@clients_per_process = clients_per_process
3032
@calls_per_client = calls_per_client
3133
@log_exceptions = log_exceptions
34+
@tls = tls
3235
end
3336

3437
def run
3538
@clients_per_process.times do
36-
socket = Thrift::Socket.new(@host, @port)
39+
socket = if @tls
40+
ssl_context = OpenSSL::SSL::SSLContext.new.tap do |ctx|
41+
ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER
42+
ctx.min_version = OpenSSL::SSL::TLS1_2_VERSION
43+
44+
keys_dir = File.expand_path("../../../test/keys", __dir__)
45+
ctx.ca_file = File.join(keys_dir, "CA.pem")
46+
ctx.cert = OpenSSL::X509::Certificate.new(File.open(File.join(keys_dir, "client.crt")))
47+
ctx.cert_store = OpenSSL::X509::Store.new
48+
ctx.cert_store.add_file(File.join(keys_dir, 'server.pem'))
49+
ctx.key = OpenSSL::PKey::RSA.new(File.open(File.join(keys_dir, "client.key")))
50+
end
51+
52+
Thrift::SSLSocket.new(@host, @port, nil, ssl_context)
53+
else
54+
Thrift::Socket.new(@host, @port)
55+
end
3756
transport = Thrift::FramedTransport.new(socket)
3857
protocol = Thrift::BinaryProtocol.new(transport)
3958
client = ThriftBenchmark::BenchmarkService::Client.new(protocol)
@@ -68,7 +87,8 @@ def print_exception(e)
6887
end
6988

7089
log_exceptions = true if ARGV[0] == '-log-exceptions' and ARGV.shift
90+
tls = true if ARGV[0] == '-tls' and ARGV.shift
7191

7292
host, port, clients_per_process, calls_per_client = ARGV
7393

74-
Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions).run
94+
Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions, tls).run

lib/rb/benchmark/server.rb

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
#
1919

2020
$:.unshift File.dirname(__FILE__) + '/../lib'
21+
$:.unshift File.dirname(__FILE__) + '/../ext'
2122
require 'thrift'
23+
require 'openssl'
2224
$:.unshift File.dirname(__FILE__) + "/gen-rb"
2325
require 'benchmark_service'
2426

@@ -36,10 +38,26 @@ def fibonacci(n)
3638
end
3739
end
3840

39-
def self.start_server(host, port, serverClass)
41+
def self.start_server(host, port, serverClass, tls)
4042
handler = BenchmarkHandler.new
4143
processor = ThriftBenchmark::BenchmarkService::Processor.new(handler)
42-
transport = ServerSocket.new(host, port)
44+
transport = if tls
45+
ssl_context = OpenSSL::SSL::SSLContext.new.tap do |ctx|
46+
ctx.verify_mode = OpenSSL::SSL::VERIFY_PEER
47+
ctx.min_version = OpenSSL::SSL::TLS1_2_VERSION
48+
49+
keys_dir = File.expand_path("../../../test/keys", __dir__)
50+
ctx.ca_file = File.join(keys_dir, "CA.pem")
51+
ctx.cert = OpenSSL::X509::Certificate.new(File.open(File.join(keys_dir, "server.crt")))
52+
ctx.cert_store = OpenSSL::X509::Store.new
53+
ctx.cert_store.add_file(File.join(keys_dir, 'client.pem'))
54+
ctx.key = OpenSSL::PKey::RSA.new(File.open(File.join(keys_dir, "server.key")))
55+
end
56+
57+
Thrift::SSLServerSocket.new(host, port, ssl_context)
58+
else
59+
ServerSocket.new(host, port)
60+
end
4361
transport_factory = FramedTransportFactory.new
4462
args = [processor, transport, transport_factory, nil, 20]
4563
if serverClass == NonblockingServer
@@ -68,9 +86,11 @@ def resolve_const(const)
6886
const and const.split('::').inject(Object) { |k,c| k.const_get(c) }
6987
end
7088

89+
tls = true if ARGV[0] == '-tls' and ARGV.shift
90+
7191
host, port, serverklass = ARGV
7292

73-
Server.start_server(host, port.to_i, resolve_const(serverklass))
93+
Server.start_server(host, port.to_i, resolve_const(serverklass), tls)
7494

7595
# let our host know that the interpreter has started
7696
# ideally we'd wait until the server was serving, but we don't have a hook for that

0 commit comments

Comments
 (0)