diff --git a/.github/workflows/elixir.yml b/.github/workflows/elixir.yml index 6016e2a..61a3fcb 100644 --- a/.github/workflows/elixir.yml +++ b/.github/workflows/elixir.yml @@ -43,7 +43,7 @@ jobs: sudo apt-get install -y mssql-tools unixodbc-dev - uses: actions/checkout@v2 - name: Setup elixir - uses: actions/setup-elixir@v1 + uses: erlef/setup-elixir@v1 with: otp-version: ${{matrix.otp}} elixir-version: ${{matrix.elixir}} diff --git a/README.md b/README.md index 740005e..a518582 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,22 @@ config :your_app, :tds_conn, port: 1433 ``` +or with ssl + +```elixir +import Mix.Config + +config :your_app, :tds_conn, + hostname: "localhost", + username: "test_user", + password: "test_password", + database: "test_db", + port: 1433, + ssl: true, + ssl_opts: [] # add key or leave empty for selfsigned certs, accepts :ssl.client_option() + +``` + Then using `Application.get_env(:your_app, :tds_conn)` use this as first parameter in `Tds.start_link/1` function. There is additional parameter that can be used in configuration and diff --git a/config/dev.exs b/config/dev.exs index 38c3ae5..27ee941 100644 --- a/config/dev.exs +++ b/config/dev.exs @@ -1,3 +1,6 @@ use Mix.Config config :elixir, :time_zone_database, Tzdata.TimeZoneDatabase + +config :tds, + opts: [hostname: "nitrox", username: "sa", password: "some!Password", database: "test", ssl: true, ssl_opts: [certfile: "/Users/mjaric/prj/github/tds/mssql.pem", keyfile: "/Users/mjaric/prj/github/tds/mssql.key"]] diff --git a/config/test.exs b/config/test.exs index ede1271..463a563 100644 --- a/config/test.exs +++ b/config/test.exs @@ -12,4 +12,5 @@ config :tds, database: "test", trace: false, set_allow_snapshot_isolation: :on + # show_sensitive_data_on_connection_error: true ] diff --git a/lib/ntlm.ex b/lib/ntlm.ex new file mode 100644 index 0000000..78ec70e --- /dev/null +++ b/lib/ntlm.ex @@ -0,0 +1,267 @@ +defmodule Ntlm do + @moduledoc """ + This module provides encoders and decoders for NTLM + negotiation and authentincation + """ + require Bitwise + + @ntlm_NegotiateUnicode 0x00000001 + @ntlm_NegotiateOEM 0x00000002 + @ntlm_RequestTarget 0x00000004 + @ntlm_Unknown9 0x00000008 + @ntlm_NegotiateSign 0x00000010 + @ntlm_NegotiateSeal 0x00000020 + @ntlm_NegotiateDatagram 0x00000040 + @ntlm_NegotiateLanManagerKey 0x00000080 + @ntlm_Unknown8 0x00000100 + @ntlm_NegotiateNTLM 0x00000200 + @ntlm_NegotiateNTOnly 0x00000400 + @ntlm_Anonymous 0x00000800 + @ntlm_NegotiateOemDomainSupplied 0x00001000 + @ntlm_NegotiateOemWorkstationSupplied 0x00002000 + @ntlm_Unknown6 0x00004000 + @ntlm_NegotiateAlwaysSign 0x00008000 + @ntlm_TargetTypeDomain 0x00010000 + @ntlm_TargetTypeServer 0x00020000 + @ntlm_TargetTypeShare 0x00040000 + @ntlm_NegotiateExtendedSecurity 0x00080000 + @ntlm_NegotiateIdentify 0x00100000 + @ntlm_Unknown5 0x00200000 + @ntlm_RequestNonNTSessionKey 0x00400000 + @ntlm_NegotiateTargetInfo 0x00800000 + @ntlm_Unknown4 0x01000000 + @ntlm_NegotiateVersion 0x02000000 + @ntlm_Unknown3 0x04000000 + @ntlm_Unknown2 0x08000000 + @ntlm_Unknown1 0x10000000 + @ntlm_Negotiate128 0x20000000 + @ntlm_NegotiateKeyExchange 0x40000000 + @ntlm_Negotiate56 0x80000000 + + @type domain :: String.t() + @type username :: String.t() + @type password :: String.t() + @type negotiation_option :: {:domain, domain()} | {:workstation, String.t()} + @type negotiation_options :: [negotiation_option()] + + @doc """ + Builds NTLM negotiation message `<<"NTLMSSP", 0x00, 0x01 ...>>` + + - `opts` - is a `Keyword.t` list that requires `:domain` key and accepts + optinal `:workstation` string. Both values can only contain valid ASCII + characters + """ + @spec negotiate(negotiation_options) :: binary() + def negotiate(negotiation_options) do + fixed_data_len = 40 + + domain = + :unicode.characters_to_binary( + negotiation_options[:domain], + :unicode, + :latin1 + ) + + domain_length = String.length(negotiation_options[:domain]) + + workstation = + negotiation_options + |> Keyword.get(:workstation) + |> Kernel.||("") + + workstation_length = String.length(workstation) + workstation = :unicode.characters_to_binary(workstation, :unicode, :latin1) + + type1_flags = type1_flags(workstation != <<>>) + + << + "NTLMSSP", + 0x00, + 0x01::little-unsigned-32, + type1_flags::little-unsigned-32, + domain_length::little-unsigned-16, + domain_length::little-unsigned-16, + fixed_data_len + workstation_length::little-unsigned-32, + workstation_length::little-unsigned-16, + workstation_length::little-unsigned-16, + fixed_data_len::little-unsigned-32, + 5, + 0, + 2195::little-unsigned-16, + 0, + 0, + 0, + 15, + domain::binary-size(domain_length)-unit(8), + workstation::binary-size(workstation_length)-unit(8) + >> + end + + @spec authenticate(domain(), username(), password(), binary(), binary()) :: + binary() + def authenticate(domain, username, password, server_data, server_nonce) do + domain = ucs2(domain) + domain_len = byte_size(domain) + username = ucs2(username) + username_len = byte_size(username) + lmv2_len = 24 + ntlmv2_len = 16 + base_idx = 64 + dn_idx = base_idx + un_idx = dn_idx + domain_len + l2_idx = un_idx + username_len * 2 + nt_idx = l2_idx + lmv2_len + client_nonce = client_nonce() + + gen_time = + NaiveDateTime.utc_now() + |> DateTime.from_naive!("Etc/UTC") + |> DateTime.to_unix() + + fixed = + <<"NTLMSSP", 0, 0x03::little-unsigned-32, lmv2_len::little-unsigned-16, + l2_idx::little-unsigned-32, ntlmv2_len::little-unsigned-16, + ntlmv2_len::little-unsigned-16, nt_idx::little-unsigned-32, + domain_len::little-unsigned-16, domain_len::little-unsigned-16, + dn_idx::little-unsigned-32, username_len::little-unsigned-16, + username_len::little-unsigned-16, un_idx::little-unsigned-32, + 0x00::little-unsigned-16, 0x00::little-unsigned-16, + base_idx::little-unsigned-32, 0x00::little-unsigned-16, + 0x00::little-unsigned-16, base_idx::little-unsigned-32, + 0x8201::little-unsigned-16, 0x00::little-unsigned-16>> + + [ + fixed, + domain, + username, + lvm2_response(domain, username, password, server_nonce, client_nonce), + ntlmv2_response( + domain, + username, + password, + server_nonce, + server_data, + client_nonce, + gen_time + ), + [0x01, 0x01, 0x00, 0x00], + as_timestamp(gen_time), + client_nonce, + [0x00, 0x00], + server_data, + [0x00, 0x00] + ] + |> IO.iodata_to_binary() + end + + defp lvm2_response(domain, username, password, server_nonce, client_nonce) do + hash = ntv2_hash(domain, username, password) + data = server_nonce <> client_nonce + new_hash = hmac_md5(data, hash) + [new_hash, client_nonce] + end + + defp ntlmv2_response( + domain, + username, + password, + server_nonce, + server_data, + client_nonce, + gen_time + ) do + timestamp = as_timestamp(gen_time) + hash = ntv2_hash(domain, username, password) + target_info_len = byte_size(server_data) + data = << + server_nonce::binary-size(8)-unit(8), + 0x0101::little-unsigned-32, + 0x0000::little-unsigned-32, + timestamp::binary-size(8)-unit(8), + client_nonce::binary-size(8)-unit(8), + 0x0000::unsigned-32, + server_data::binary-size(target_info_len)-unit(8), + 0x0000::little-unsigned-32 + >> + + hmac_md5(data, hash) + end + + defp client_nonce() do + 1..8 + |> Enum.map(fn _ -> :rand.uniform(255) end) + |> IO.iodata_to_binary() + end + + defp type1_flags(workstation?) do + 0x00000000 + |> Bitwise.bor(@ntlm_NegotiateUnicode) + |> Bitwise.bor(@ntlm_NegotiateOEM) + |> Bitwise.bor(@ntlm_RequestTarget) + |> Bitwise.bor(@ntlm_Unknown9) + |> Bitwise.bor(@ntlm_NegotiateSign) + |> Bitwise.bor(@ntlm_NegotiateSeal) + |> Bitwise.bor(@ntlm_NegotiateDatagram) + |> Bitwise.bor(@ntlm_NegotiateLanManagerKey) + |> Bitwise.bor(@ntlm_Unknown8) + |> Bitwise.bor(@ntlm_NegotiateNTLM) + |> Bitwise.bor(@ntlm_NegotiateNTOnly) + |> Bitwise.bor(@ntlm_Anonymous) + |> Bitwise.bor(@ntlm_NegotiateOemDomainSupplied) + |> Bitwise.bor( + if(workstation?, + do: @ntlm_NegotiateOemWorkstationSupplied, + else: 0x00000000 + ) + ) + |> Bitwise.bor(@ntlm_Unknown6) + |> Bitwise.bor(@ntlm_NegotiateAlwaysSign) + |> Bitwise.bor(@ntlm_TargetTypeDomain) + |> Bitwise.bor(@ntlm_TargetTypeServer) + |> Bitwise.bor(@ntlm_TargetTypeShare) + |> Bitwise.bor(@ntlm_NegotiateExtendedSecurity) + |> Bitwise.bor(@ntlm_NegotiateIdentify) + |> Bitwise.bor(@ntlm_Unknown5) + |> Bitwise.bor(@ntlm_RequestNonNTSessionKey) + |> Bitwise.bor(@ntlm_NegotiateTargetInfo) + |> Bitwise.bor(@ntlm_Unknown4) + |> Bitwise.bor(@ntlm_NegotiateVersion) + |> Bitwise.bor(@ntlm_Unknown3) + |> Bitwise.bor(@ntlm_Unknown2) + |> Bitwise.bor(@ntlm_Unknown1) + |> Bitwise.bor(@ntlm_Negotiate128) + |> Bitwise.bor(@ntlm_NegotiateKeyExchange) + |> Bitwise.bor(@ntlm_Negotiate56) + end + + defp as_timestamp(unix) do + tenth_of_usec = (unix + 11_644_473_600) * 10_000_000 + lo = Bitwise.band(tenth_of_usec, 0xFFFFFFFF) + + hi = + tenth_of_usec + |> Bitwise.>>>(32) + |> Bitwise.band(0xFFFFFFFF) + + <> + end + + defp ntv2_hash(domain, user, password) do + hash = nt_hash(password) + identity = ucs2(String.upcase(user) <> String.upcase(domain)) + hmac_md5(identity, hash) + end + + defp nt_hash(text) do + text = ucs2(text) + :crypto.hash(:md4, text) + end + + defp hmac_md5(data, key) do + :crypto.hmac(:md5, key, data) + end + + defp ucs2(str) do + :unicode.characters_to_binary(str, :unicode, {:utf16, :little}) + end +end diff --git a/lib/tds/messages.ex b/lib/tds/messages.ex index 9a09d2f..a9b9873 100644 --- a/lib/tds/messages.ex +++ b/lib/tds/messages.ex @@ -19,6 +19,7 @@ defmodule Tds.Messages do defrecord :msg_attn, [] # responses + defrecord :msg_preloginack, [:response] defrecord :msg_loginack, [:redirect] defrecord :msg_prepared, [:params] defrecord :msg_sql_result, [:columns, :rows, :row_count] @@ -61,6 +62,14 @@ defmodule Tds.Messages do ## Parsers + def parse(:prelogin, packet_data, s) do + response = + packet_data + |> Tds.Protocol.Prelogin.decode(s) + + {msg_preloginack(response: response), s} + end + def parse(:login, packet_data, s) do packet_data |> decode_tokens() @@ -236,16 +245,16 @@ defmodule Tds.Messages do encode(msg, env) end - defp encode(msg_prelogin(params: _params), _env) do - version_data = <<11, 0, 12, 56, 0, 0>> - version_length = byte_size(version_data) - version_offset = 0x06 - version = <<0x00, version_offset::size(16), version_length::size(16)>> - terminator = <<0xFF>> - prelogin_data = version_data - data = version <> terminator <> prelogin_data - encode_packets(0x12, data) - # encode_header(0x12, data) <> data + defp encode(msg_prelogin(params: opts), _env) do + # version_data = <<11, 0, 12, 56, 0, 0>> + # version_length = byte_size(version_data) + # version_offset = 0x06 + # version = <<0x00, version_offset::size(16), version_length::size(16)>> + # terminator = <<0xFF>> + # prelogin_data = version_data + # data = version <> terminator <> prelogin_data + # encode_packets(0x12, data) + Tds.Protocol.Prelogin.encode(opts) end defp encode(msg_login(params: params), _env) do @@ -292,7 +301,7 @@ defmodule Tds.Messages do # to by IbPassword, the client SHOULD first swap the four high bits with the # four low bits and then do a bit-XOR with 0xA5 (10100101). - clt_int_name = "ODBC" + clt_int_name = "tdsx" clt_int_name_ucs = to_little_ucs2(clt_int_name) database = params[:database] || "" database_ucs = to_little_ucs2(database) @@ -442,7 +451,13 @@ defmodule Tds.Messages do encode_packets(0x03, data) end - defp encode(msg_transmgr(command: "TM_BEGIN_XACT", isolation_level: isolation_level), %{trans: trans}) do + defp encode( + msg_transmgr( + command: "TM_BEGIN_XACT", + isolation_level: isolation_level + ), + %{trans: trans} + ) do isolation = encode_isolation_level(isolation_level) encode_trans(5, trans, <>) end @@ -451,15 +466,21 @@ defmodule Tds.Messages do encode_trans(7, trans, <<0, 0>>) end - defp encode(msg_transmgr(command: "TM_ROLLBACK_XACT", name: name), %{trans: trans}) do - payload = unless name > 0, - do: <<0x00::size(2)-unit(8)>>, - else: <<2::unsigned-8, name::little-size(2)-unit(8), 0x0::size(1)-unit(8)>> + defp encode(msg_transmgr(command: "TM_ROLLBACK_XACT", name: name), %{ + trans: trans + }) do + payload = + unless name > 0, + do: <<0x00::size(2)-unit(8)>>, + else: + <<2::unsigned-8, name::little-size(2)-unit(8), 0x0::size(1)-unit(8)>> encode_trans(8, trans, payload) end - defp encode(msg_transmgr(command: "TM_SAVE_XACT", name: savepoint), %{trans: trans}) do + defp encode(msg_transmgr(command: "TM_SAVE_XACT", name: savepoint), %{ + trans: trans + }) do encode_trans(9, trans, <<2::unsigned-8, savepoint::little-size(2)-unit(8)>>) end @@ -493,7 +514,8 @@ defmodule Tds.Messages do all_headers = <> <> headers data = - all_headers <> <> + all_headers <> + <> encode_packets(0x0E, data) end @@ -519,7 +541,7 @@ defmodule Tds.Messages do # for that parameter. Otherwise RPC will fail and we must use ProceName # instead. But we want to avoid execution overhead with named approach # hence ommiting @handle from parameter name - %{p| name: ""} + %{p | name: ""} p -> # other paramters should be named diff --git a/lib/tds/protocol.ex b/lib/tds/protocol.ex index 4dc4d2e..1f036c4 100644 --- a/lib/tds/protocol.ex +++ b/lib/tds/protocol.ex @@ -23,7 +23,7 @@ defmodule Tds.Protocol do :serializable ] - @type sock :: {:gen_tcp | :ssl, pid} + @type sock :: {:gen_tcp | :ssl, port()} @type env :: %{ trans: <<_::8>>, savepoint: non_neg_integer, @@ -42,7 +42,7 @@ defmodule Tds.Protocol do @type t :: %__MODULE__{ sock: nil | sock, usock: nil | pid, - itcp: term, + itcp: non_neg_integer() | String.t(), opts: nil | Keyword.t(), state: state, result: nil | list(), @@ -53,6 +53,7 @@ defmodule Tds.Protocol do defstruct sock: nil, usock: nil, + # instance port itcp: nil, opts: nil, # Tells if connection is ready or executing command @@ -140,10 +141,10 @@ defmodule Tds.Protocol do {:disconnect, err, s} end - def checkout(%{sock: {mod, sock}} = s) do + def checkout(%{sock: {mod, _sock}} = s) do sock_mod = inspect(mod) - case :inet.setopts(sock, active: false) do + case setopts(s.sock, active: false) do :ok -> {:ok, s} @@ -162,10 +163,10 @@ defmodule Tds.Protocol do {:disconnect, err, s} end - def checkin(%{sock: {mod, sock}} = s) do + def checkin(%{sock: {mod, _sock}} = s) do sock_mod = inspect(mod) - case :inet.setopts(sock, active: :once) do + case setopts(s.sock, active: :once) do :ok -> {:ok, s} @@ -236,7 +237,7 @@ defmodule Tds.Protocol do end @impl DBConnection - @spec handle_close(Tds.Query.t(), nil | keyword | map, t()) :: + @spec handle_close(Tds.Query.t(), nil | maybe_improper_list() | map(), t()) :: {:ok, Tds.Result.t(), new_state :: t()} | {:error | :disconnect, Exception.t(), new_state :: t()} def handle_close(query, opts, s) do @@ -405,13 +406,12 @@ defmodule Tds.Protocol do :ok = :inet.setopts(sock, buffer: buffer) - case login(%{s | sock: {:gen_tcp, sock}}) do + case prelogin(%{s | sock: {:gen_tcp, sock}}) do {:error, error, _state} -> - :gen_tcp.close(sock) {:error, error} - r -> - r + other -> + other end {:error, error} -> @@ -462,6 +462,29 @@ defmodule Tds.Protocol do end end + ## ssl + + defp ssl_connect(%{sock: {:gen_tcp, sock}, opts: opts} = s) do + {:ok, _} = Application.ensure_all_started(:ssl) + :inet.setopts(sock, active: false) + + case Tds.Tls.connect(sock, opts[:ssl_opts] || []) do + {:ok, ssl_sock} -> + state = %{s | sock: {:ssl, ssl_sock} } + {:ok, state} + + {:error, reason} -> + error = + Tds.Error.exception( + "Unable to establish secure connection to server due #{ + inspect(reason) + }" + ) + :gen_tcp.close(sock) + {:error, error, s} + end + end + def handle_info({:udp_error, _, :econnreset}, s) do msg = "Tds encountered an error while connecting to the Sql Server " <> @@ -474,10 +497,7 @@ defmodule Tds.Protocol do {:tcp, _, _data}, %{sock: {mod, sock}, opts: opts, state: :prelogin} = s ) do - case mod do - :gen_tcp -> :inet.setopts(sock, active: false) - :ssl -> :ssl.setopts(sock, active: false) - end + setopts(s.sock, active: false) login(%{s | opts: opts, sock: {mod, sock}}) end @@ -544,22 +564,16 @@ defmodule Tds.Protocol do def prelogin(%{opts: opts} = s) do msg = msg_prelogin(params: opts) - case msg_send(msg, s) do - {:ok, s} -> - {:noreply, %{s | state: :prelogin}} - - {:error, reason, s} -> - error(%Tds.Error{message: "tcp send: #{reason}"}, s) - - any -> - any + case msg_send(msg, %{s | state: :prelogin}) do + {:ok, s} -> login(s) + any -> any end end def login(%{opts: opts} = s) do msg = msg_login(params: opts) - case login_send(msg, s) do + case login_send(msg, %{s | state: :login}) do {:ok, s} -> {:ok, %{s | state: :ready}} @@ -763,6 +777,14 @@ defmodule Tds.Protocol do end end + def message(:prelogin, msg_preloginack(response: response), _) do + case response do + {:login, s} -> {:ok, s} + {:encrypt, s} -> ssl_connect(s) + other -> other + end + end + def message( :login, msg_loginack(redirect: %{hostname: host, port: port}), @@ -779,9 +801,7 @@ defmodule Tds.Protocol do connect(new_opts) end - def message(:login, msg_loginack(), %{opts: opts} = s) do - state = %{s | opts: clean_opts(opts)} - + def message(:login, msg_loginack(), %{opts: opts} = state) do opts |> conn_opts() |> IO.iodata_to_binary() @@ -837,8 +857,9 @@ defmodule Tds.Protocol do end # Send Command To Sql Server - defp login_send(msg, %{sock: {mod, sock}, env: env} = s) do + defp login_send(msg, %{sock: {mod, sock}, env: env, opts: opts} = s) do paks = encode_msg(msg, env) + s = %{s | opts: clean_opts(opts)} Enum.each(paks, fn pak -> mod.send(sock, pak) @@ -846,7 +867,7 @@ defmodule Tds.Protocol do case msg_recv(s) do {:disconnect, ex, s} -> - {:error, ex, s} + {:disconnect, ex, s} buffer -> buffer @@ -857,51 +878,33 @@ defmodule Tds.Protocol do defp msg_send( msg, - %{sock: {mod, sock}, env: env, state: state, opts: opts} = s + %{sock: {mod, port}, env: env, opts: opts} = s ) do - :inet.setopts(sock, active: false) + setopts(s.sock, active: false) opts |> Keyword.get(:use_elixir_calendar_types, false) |> use_elixir_calendar_types() - {t_send, _} = - :timer.tc(fn -> - msg - |> encode_msg(env) - |> Enum.each(&mod.send(sock, &1)) - end) - - {t_recv, {t_decode, result}} = - :timer.tc(fn -> - case msg_recv(s) do - {:disconnect, _ex, _s} = res -> - {0, res} - - buffer -> - :timer.tc(fn -> - buffer - |> IO.iodata_to_binary() - |> decode(s) - end) + send_result = + msg + |> encode_msg(env) + |> Enum.reduce_while(:ok, fn chunk, _ -> + case mod.send(port, chunk) do + {:error, reason} -> {:halt, {:error, reason}} + :ok -> {:cont, :ok} end end) - stm = Map.get(s, :query) - - if Keyword.get(s.opts, :trace, false) == true do - Logger.debug(fn -> - "[trace] [Tds.Protocod.msg_send/2] " <> - "state=#{inspect(state)} " <> - "send=#{Tds.Perf.to_string(t_send)} " <> - "receive=#{Tds.Perf.to_string(t_recv - t_decode)} " <> - "decode=#{Tds.Perf.to_string(t_decode)}" <> - "\n" <> - "#{inspect(stm)}" - end) + with :ok <- send_result, + buffer when is_list(buffer) <- msg_recv(s) do + buffer + |> IO.iodata_to_binary() + |> decode(s) + else + {:disconnect, _ex, _s} = res -> {0, res} + other -> other end - - result end defp msg_recv(%{sock: {mod, pid}} = s) do @@ -1175,4 +1178,11 @@ defmodule Tds.Protocol do ) end end + + defp setopts({mod, sock}, options) do + case mod do + :gen_tcp -> :inet.setopts(sock, options) + :ssl -> :ssl.setopts(sock, options) + end + end end diff --git a/lib/tds/protocol/prelogin.ex b/lib/tds/protocol/prelogin.ex new file mode 100644 index 0000000..8b55a1f --- /dev/null +++ b/lib/tds/protocol/prelogin.ex @@ -0,0 +1,284 @@ +defmodule Tds.Protocol.Prelogin do + @moduledoc false + import Tds.Protocol.Grammar + require Logger + + # defstruct version: nil, + + @type state :: Tds.Protocol.t() + @type packet_data :: iodata() + + @type response :: + {:ok, state()} + | {:error, Exception.t() | atom(), state()} + + defstruct version: nil, + encryption: <<0x00>>, + instance: true, + threadid: nil, + mars: false, + fedauth: false, + nonceopt: nil + + @type t :: %__MODULE__{ + version: tuple(), + encryption: <<_::8>>, + instance: boolean(), + mars: boolean() + } + + @version_token 0x00 + @encryption_token 0x01 + @instopt_token 0x02 + @threadid_token 0x03 + @mars_token 0x04 + @fedauth_token 0x06 + @nonceopt_token 0x07 + @termintator_token 0xFF + + # ENCODE + + @spec encode(maybe_improper_list()) :: [binary(), ...] + def encode(opts) do + stream = [ + encode_version(opts), + encode_encryption(opts), + # when instance id check is sent, encryption is not negotiated + # encode_instance(opts), + encode_threadid(opts), + encode_mars(opts), + encode_fedauth(opts) + ] + + start_offset = 5 * Enum.count(stream) + 1 + + {iodata, _} = + stream + |> Enum.reduce({[[], @termintator_token, []], start_offset}, fn + {token, option_data}, {[options, term, data], offset} -> + data_length = byte_size(option_data) + + options = [ + options, + <> + ] + + data = [data, option_data] + {[options, term, data], offset + data_length} + end) + + data = IO.iodata_to_binary(iodata) + Tds.Messages.encode_packets(0x12, data) + end + + defp encode_version(_opts) do + data = + Application.spec(:tds) + |> Keyword.get(:vsn) + |> to_string() + |> String.split(".") + |> Enum.map(&(Integer.parse(&1, 10) |> elem(0))) + |> case do + [major, minor, build] -> + <> + + [major, minor] -> + <<0x00, 0x00, minor, major, 0x00, 0x00>> + + _ -> + # probably PRE-release + <<0x01, 0x00, 0, 1, 0x00, 0x00>> + end + + {@version_token, data} + end + + defp encode_encryption(opts) do + data = + if Keyword.get(opts, :ssl, false), + do: <<0x01::byte>>, + else: <<0x02::byte>> + + {@encryption_token, data} + end + + defp encode_instance(opts) do + # not working for some reason + instance = Keyword.get(opts, :instance) + + if is_nil(instance) do + {@instopt_token, <<0x00>>} + else + {@instopt_token, instance <> <<0x00>>} + end + end + + defp encode_threadid(_opts) do + pid_serial = + self() + |> inspect() + |> String.split(".") + |> Enum.at(1) + |> Integer.parse() + |> elem(0) + + {@threadid_token, <>} + end + + defp encode_mars(_opts) do + {@mars_token, <<0x00>>} + end + + defp encode_fedauth(_opts) do + {@fedauth_token, <<0x01>>} + end + + # DECODE + @spec decode(iodata(), state()) :: + {:encrypt, state()} + | {:login, state()} + | {:disconnect, Tds.Error.t(), state()} + def decode(packet_data, %{opts: opts} = s) do + ecrypt = Keyword.get(opts, :ssl, false) + + {:ok, %{encryption: encryption, instance: instance}} = + packet_data + |> IO.iodata_to_binary() + |> decode_tokens([], s) + + case {ecrypt, encryption, instance} do + {_, _, false} -> + msg = "Connection terminated, connected instance is not '#{instance}'!" + disconnect(msg, s) + + {false, enc, _} when enc in [<<0x00>>, <<0x02>>] -> + {:login, s} + + {false, <<0x03>>, _} -> + disconnect("Server does not allow the requested encryption level.", s) + + {true, <<0x00>>, _} -> + disconnect("Server does not allow the requested encryption level.", s) + + {true, <<0x03>>, _} -> + disconnect("Server does not allow the requested encryption level.", s) + + {_, _, _} -> + Logger.debug("Upgrading connection to SSL/TSL.") + {:encrypt, s} + end + end + + defp decode_tokens( + <<@version_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:version, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@encryption_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:encryption, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@instopt_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:encryption, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@threadid_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:threadid, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@mars_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:mars, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@fedauth_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:fedauth, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@nonceopt_token, offset::ushort, length::ushort, tail::binary>>, + tokens, + s + ) do + tokens = [{:nonceopt, offset, length} | tokens] + decode_tokens(tail, tokens, s) + end + + defp decode_tokens( + <<@termintator_token, tail::binary>>, + tokens, + _s + ) do + {:ok, decode_data(Enum.reverse(tokens), tail, %__MODULE__{})} + end + + defp decode_data([], _, result), do: result + + defp decode_data([{key, _, length} | tokens], bin, m) do + <> = bin + + case key do + :version -> + <> = data + + decode_data( + tokens, + tail, + %{m | version: {major, minor, patch, trivial, subbuild}} + ) + + :encryption -> + decode_data( + tokens, + tail, + %{m | encryption: data} + ) + + :instance -> + decode_data( + tokens, + tail, + %{m | instance: data == <<0x00>>} + ) + + # :threadid -> + # :mars -> + # :fedauth -> + # :nonceopt -> + _ -> + decode_data(tokens, tail, m) + end + end + + defp disconnect(message, s) do + {:disconnect, Tds.Error.exception(message), s} + end +end diff --git a/lib/tds/tls.ex b/lib/tds/tls.ex new file mode 100644 index 0000000..43d78ed --- /dev/null +++ b/lib/tds/tls.ex @@ -0,0 +1,218 @@ +defmodule Tds.Tls do + @moduledoc false + require Logger + use GenServer + import Kernel, except: [send: 2] + import Tds.BinaryUtils + + defstruct [:socket, :ssl_opts, :owner_pid, :handshake, :buffer] + + def connect(socket, ssl_opts) do + ssl_opts = + ssl_opts ++ + [ + active: false, + cb_info: {Tds.Tls, :tcp, :tcp_closed, :tcp_error} + ] + + :inet.setopts(socket, active: false) + + with {:ok, pid} <- GenServer.start_link(__MODULE__, {socket, ssl_opts}, []), + :ok <- :gen_tcp.controlling_process(socket, pid) do + res = :ssl.connect(socket, ssl_opts, :infinity) + # todo: remove this line and handle it when server respond with 0x12 message with status 0x01 + GenServer.cast(pid, :handshake_complete) + res + else + error -> error + end + end + + def controlling_process(socket, tls_conn_pid) do + {:connected, pid} = Port.info(socket, :connected) + GenServer.call(pid, {:controlling_process, tls_conn_pid}) + end + + def send(socket, payload) do + {:connected, pid} = Port.info(socket, :connected) + GenServer.call(pid, {:send, payload}) + end + + def recv(socket, length, timeout \\ :infinity) do + {:connected, pid} = Port.info(socket, :connected) + GenServer.call(pid, {:recv, length, timeout}, timeout) + end + + defdelegate getopts(port, options), to: :inet + + # defdelegate setopts(socket, options), to: :inet + def setopts(socket, options) do + {:connected, pid} = Port.info(socket, :connected) + GenServer.call(pid, {:setopts, options}) + end + + defdelegate peername(socket), to: :inet + + :exports + |> :gen_tcp.module_info() + |> Enum.reject(fn {fun, _} -> + fun in [:send, :recv, :module_info, :controlling_process] + end) + |> Enum.each(fn + {name, 0} -> + defdelegate unquote(name)(), to: :gen_tcp + + {name, 1} -> + defdelegate unquote(name)(arg1), to: :gen_tcp + + {name, 2} -> + defdelegate unquote(name)(arg1, arg2), to: :gen_tcp + + {name, 3} -> + defdelegate unquote(name)(arg1, arg2, arg3), to: :gen_tcp + + {name, 4} -> + defdelegate unquote(name)(arg1, arg2, arg3, arg4), to: :gen_tcp + end) + + # SERVER + def init({socket, ssl_opts}) do + {:ok, %__MODULE__{socket: socket, ssl_opts: ssl_opts, handshake: true}} + end + + def handle_call({:controlling_process, tls_conn_pid}, _from, s) do + {:reply, :ok, %{s | owner_pid: tls_conn_pid}} + end + + def handle_call( + {:setopts, options}, + _from, + %{socket: socket, handshake: hs} = s + ) do + tds_header_size = if hs == true, do: 8, else: 0 + + opts = + options + |> Enum.map(fn + {:active, val} when is_number(val) -> {:active, val + tds_header_size} + val -> val + end) + + {:reply, :inet.setopts(socket, opts), s} + end + + def handle_call({:send, data}, _from, %{socket: socket, handshake: true} = s) do + size = IO.iodata_length(data) + 8 + + header = + <<0x12, 0x01, size::unsigned-size(2)-unit(8), 0x00, 0x00, 0x00, 0x00>> + + resp = :gen_tcp.send(socket, [header, data]) + {:reply, resp, s} + end + + def handle_call({:send, data}, _from, %{socket: socket, handshake: false} = s) do + resp = :gen_tcp.send(socket, data) + {:reply, resp, s} + end + + # def handle_call({:recv, length, timeout}, _from, %{socket: socket, handshake: true} = s) do + # res = case :gen_tcp.recv(socket, length, timeout) do + # {:ok, data} + # end + # {:reply, res, s} + # end + + def handle_call({:recv, length, timeout}, _from, %{socket: socket} = s) do + res = :gen_tcp.recv(socket, length, timeout) + {:reply, res, s} + end + + def handle_cast(:handshake_complete, s) do + {:noreply, %{s | handshake: false}} + end + + def handle_info( + {:tcp, _, _} = msg, + %{owner_pid: pid, handshake: false, buffer: nil} = s + ) do + Kernel.send(pid, msg) + {:noreply, s} + end + + def handle_info( + {:tcp, port, <<0x12, 0, size::unsigned-16, _::32, tail::binary>>}, + %{socket: socket, owner_pid: pid, buffer: nil, handshake: true} = s + ) do + expecting = size - 8 + + case tail do + <> -> + Kernel.send(pid, {:tcp, socket, ssl_payload}) + handle_info({:tcp, port, next_packet}, %{s | buffer: nil}) + + next_slice -> + state = %{s | buffer: {next_slice, expecting}} + {:noreply, state} + end + end + + def handle_info( + {:tcp, port, <<0x12, 1, size::unsigned-16, _::32, tail::binary>>}, + %{socket: socket, owner_pid: pid, buffer: nil, handshake: true} = s + ) do + expecting = size - 8 + + case tail do + <> -> + Kernel.send(pid, {:tcp, socket, ssl_payload}) + handle_info({:tcp, port, next_packet}, %{s | buffer: nil}) + + next_slice -> + state = %{s | buffer: {next_slice, expecting}} + {:noreply, state} + end + end + + def handle_info( + {:tcp, port, bin}, + %{ + socket: socket, + owner_pid: pid, + buffer: {slice, expecting}, + handshake: true + } = s + ) do + case IO.iodata_to_binary([slice, bin]) do + <> -> + Kernel.send(pid, {:tcp, socket, ssl_payload}) + handle_info({:tcp, port, next_packet}, %{s | buffer: nil}) + + next_slice -> + state = %{s | buffer: {next_slice, expecting}} + {:noreply, state} + end + end + + def handle_info( + {:tcp, _, _} = msg, + %{owner_pid: pid, handshake: true, buffer: nil} = s + ) do + Kernel.send(pid, msg) + {:noreply, s} + end + + def handle_info({tag, _} = msg, %{owner_pid: pid} = s) + when tag in [:tcp_closed, :ssl_closed] do + # todo + send(pid, msg) + {:stop, tag, s} + end + + def handle_info({tag, _, _} = msg, %{owner_pid: pid} = s) + when tag in [:tcp_error, :ssl_error] do + # todo + send(pid, msg) + {:stop, tag, s} + end +end diff --git a/lib/tds/versions.ex b/lib/tds/versions.ex index 0852275..2396591 100644 --- a/lib/tds/versions.ex +++ b/lib/tds/versions.ex @@ -1,25 +1,28 @@ defmodule Tds.Version do import Tds.Protocol.Grammar - defstruct version: 0x74000004, str_version: "7.4" + @default_version :v7_4 + @default_code 0x74000004 + + defstruct code: @default_code, version: @default_version @versions [ - {0x71000001, "7.1"}, - {0x72090002, "7.2"}, - {0x730A0003, "7.3.A"}, - {0x730B0003, "7.3.B"}, - {0x74000004, "7.4"} + {0x71000001, :v7_1}, + {0x72090002, :v7_2}, + {0x730A0003, :v7_3_a}, + {0x730B0003, :v7_3_b}, + {0x74000004, :v7_4} ] def decode(<>) do @versions - |> List.keyfind(key, 0, "7.4") + |> List.keyfind(key, 0, @default_version) end def encode(ver) do val = @versions - |> List.keyfind(ver, 1, 0x74000004) + |> List.keyfind(ver, 1, @default_code) <> end diff --git a/mix.exs b/mix.exs index 33ecd40..86f1c6e 100644 --- a/mix.exs +++ b/mix.exs @@ -39,7 +39,7 @@ defmodule Tds.Mixfile do def application do [ - extra_applications: [:logger, :db_connection, :decimal], + extra_applications: [:logger, :db_connection, :decimal, :inets, :ssl], env: [ json_library: Jason ] diff --git a/test/login_test.exs b/test/login_test.exs index 05bd794..3b4e2e0 100644 --- a/test/login_test.exs +++ b/test/login_test.exs @@ -1,14 +1,72 @@ defmodule LoginTest do use ExUnit.Case, async: true + import ExUnit.CaptureLog - test "Login with sql server authentication" do - # :dbg.tracer() - # :dbg.p(:all,:c) - # :dbg.tpl(Tds.Messages,:parse,:x) - # :dbg.tpl(Tds.Protocol,:message,:x) - opts = Application.fetch_env!(:tds, :opts) + setup do + hostname = + Application.fetch_env!(:tds, :opts) + |> Keyword.get(:hostname) + {:ok, + [ + options: [ + hostname: hostname, + database: "test", + backoff_type: :stop, + max_restarts: 0, + show_sensitive_data_on_connection_error: true + ] + ]} + end + + @tag :login + test "login with sql server authentication", context do + opts = Application.fetch_env!(:tds, :opts) ++ context[:options] {:ok, pid} = Tds.start_link(opts) - assert {:ok, _} = Tds.query(pid, "SELECT 1", []) + assert {:ok, %Tds.Result{}} = Tds.query(pid, "SELECT 1", []) + end + + @tag :login + test "login with non existing sql server authentication", context do + assert capture_log(fn -> + opts = [username: "sa", password: "wrong"] ++ context[:options] + assert_start_and_killed(opts) + end) =~ + "(Tds.Error) Line 1 (Error 18456): Login failed for user 'sa'" + end + + @tag :login + @tag :tls + test "login with valid sql login over tsl", context do + opts = + Application.fetch_env!(:tds, :opts) ++ + [ssl: true, ssl_opts: []] + # [ssl: true, ssl_opts: [log_debug: true, log_level: :debug]] + + assert {:ok, pid} = Tds.start_link(opts ++ context[:options]) + assert {:ok, %Tds.Result{}} = Tds.query(pid, "SELECT 1", []) + end + + @tag :login + @tag :tls + test "login with non existing sql server authentication over tls", context do + assert capture_log(fn -> + opts = + [username: "sa", password: "wrong"] ++ + context[:options] ++ + [ssl: true, ssl_opts: [log_debug: true]] + + assert_start_and_killed(opts) + end) =~ + "(Tds.Error) Line 1 (Error 18456): Login failed for user 'sa'" + end + + defp assert_start_and_killed(opts) do + Process.flag(:trap_exit, true) + + case Tds.start_link(opts) do + {:ok, pid} -> assert_receive {:EXIT, ^pid, :killed}, 1_000 + {:error, :killed} -> :ok + end end end diff --git a/test/plp_test.exs b/test/plp_test.exs index c8b03ad..11bb63d 100644 --- a/test/plp_test.exs +++ b/test/plp_test.exs @@ -101,6 +101,7 @@ defmodule PLPTest do "", ] |> Enum.join("") + :ok = query( """ diff --git a/test/test_helper.exs b/test/test_helper.exs index 9962b8a..8851893 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -89,34 +89,34 @@ end opts = Application.get_env(:tds, :opts) database = opts[:database] - -{"", 0} = - Tds.TestHelper.sqlcmd(opts, """ - IF EXISTS(SELECT * FROM sys.databases where name = '#{database}') - BEGIN - DROP DATABASE [#{database}]; - END; - CREATE DATABASE [#{database}]; - """) - -{"Changed database context to 'test'." <> _, 0} = - Tds.TestHelper.sqlcmd(opts, """ - USE [test]; - - CREATE TABLE altering ([a] int) - - CREATE TABLE [composite1] ([a] int, [b] text); - CREATE TABLE [composite2] ([a] int, [b] int, [c] int); - CREATE TABLE [uniques] ([id] int NOT NULL, CONSTRAINT UIX_uniques_id UNIQUE([id])) - """) - -{"Changed database context to 'test'." <> _, 0} = - Tds.TestHelper.sqlcmd(opts, """ - USE test - GO - CREATE SCHEMA test; - """) - +if System.get_env("TEST_AZURE") == nil do + {"", 0} = + Tds.TestHelper.sqlcmd(opts, """ + IF EXISTS(SELECT * FROM sys.databases where name = '#{database}') + BEGIN + DROP DATABASE [#{database}]; + END; + CREATE DATABASE [#{database}]; + """) + + {"Changed database context to 'test'." <> _, 0} = + Tds.TestHelper.sqlcmd(opts, """ + USE [test]; + + CREATE TABLE altering ([a] int) + + CREATE TABLE [composite1] ([a] int, [b] text); + CREATE TABLE [composite2] ([a] int, [b] int, [c] int); + CREATE TABLE [uniques] ([id] int NOT NULL, CONSTRAINT UIX_uniques_id UNIQUE([id])) + """) + + {"Changed database context to 'test'." <> _, 0} = + Tds.TestHelper.sqlcmd(opts, """ + USE test + GO + CREATE SCHEMA test; + """) +end # :dbg.start() # :dbg.tracer() # :dbg.p(:all,:c)