Skip to content

Commit 39eb287

Browse files
authored
feat: Add custom tls signer for ECP Provider. (#1402)
feat: Add custom tls signer for ECP Provider.
1 parent 9b46ee3 commit 39eb287

File tree

5 files changed

+131
-68
lines changed

5 files changed

+131
-68
lines changed

google/auth/transport/_custom_tls_signer.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ def load_signer_lib(signer_lib_path):
107107
return lib
108108

109109

110+
def load_provider_lib(provider_lib_path):
111+
_LOGGER.debug("loading provider library from %s", provider_lib_path)
112+
113+
# winmode parameter is only available for python 3.8+.
114+
lib = (
115+
ctypes.CDLL(provider_lib_path, winmode=0)
116+
if sys.version_info >= (3, 8) and os.name == "nt"
117+
else ctypes.CDLL(provider_lib_path)
118+
)
119+
120+
lib.ECP_attach_to_ctx.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
121+
lib.ECP_attach_to_ctx.restype = ctypes.c_int
122+
123+
return lib
124+
125+
110126
# Computes SHA256 hash.
111127
def _compute_sha256_digest(to_be_signed, to_be_signed_len):
112128
from cryptography.hazmat.primitives import hashes
@@ -199,21 +215,31 @@ def __init__(self, enterprise_cert_file_path):
199215
self._enterprise_cert_file_path = enterprise_cert_file_path
200216
self._cert = None
201217
self._sign_callback = None
218+
self._provider_lib = None
202219

203220
def load_libraries(self):
204-
try:
205-
with open(self._enterprise_cert_file_path, "r") as f:
206-
enterprise_cert_json = json.load(f)
207-
libs = enterprise_cert_json["libs"]
208-
signer_library = libs["ecp_client"]
209-
offload_library = libs["tls_offload"]
210-
except (KeyError, ValueError) as caught_exc:
211-
new_exc = exceptions.MutualTLSChannelError(
212-
"enterprise cert file is invalid", caught_exc
213-
)
214-
raise new_exc from caught_exc
215-
self._offload_lib = load_offload_lib(offload_library)
216-
self._signer_lib = load_signer_lib(signer_library)
221+
with open(self._enterprise_cert_file_path, "r") as f:
222+
enterprise_cert_json = json.load(f)
223+
libs = enterprise_cert_json.get("libs", {})
224+
225+
signer_library = libs.get("ecp_client", None)
226+
offload_library = libs.get("tls_offload", None)
227+
provider_library = libs.get("ecp_provider", None)
228+
229+
# Using newer provider implementation. This is mutually exclusive to the
230+
# offload implementation.
231+
if provider_library:
232+
self._provider_lib = load_provider_lib(provider_library)
233+
return
234+
235+
# Using old offload implementation
236+
if offload_library and signer_library:
237+
self._offload_lib = load_offload_lib(offload_library)
238+
self._signer_lib = load_signer_lib(signer_library)
239+
self.set_up_custom_key()
240+
return
241+
242+
raise exceptions.MutualTLSChannelError("enterprise cert file is invalid")
217243

218244
def set_up_custom_key(self):
219245
# We need to keep a reference of the cert and sign callback so it won't
@@ -224,11 +250,22 @@ def set_up_custom_key(self):
224250
)
225251

226252
def attach_to_ssl_context(self, ctx):
227-
# In the TLS handshake, the signing operation will be done by the
228-
# sign_callback.
229-
if not self._offload_lib.ConfigureSslContext(
230-
self._sign_callback,
231-
ctypes.c_char_p(self._cert),
232-
_cast_ssl_ctx_to_void_p(ctx._ctx._context),
233-
):
234-
raise exceptions.MutualTLSChannelError("failed to configure SSL context")
253+
if self._provider_lib:
254+
if not self._provider_lib.ECP_attach_to_ctx(
255+
_cast_ssl_ctx_to_void_p(ctx._ctx._context),
256+
self._enterprise_cert_file_path.encode("ascii"),
257+
):
258+
raise exceptions.MutualTLSChannelError(
259+
"failed to configure ECP Provider SSL context"
260+
)
261+
elif self._offload_lib and self._signer_lib:
262+
if not self._offload_lib.ConfigureSslContext(
263+
self._sign_callback,
264+
ctypes.c_char_p(self._cert),
265+
_cast_ssl_ctx_to_void_p(ctx._ctx._context),
266+
):
267+
raise exceptions.MutualTLSChannelError(
268+
"failed to configure ECP Offload SSL context"
269+
)
270+
else:
271+
raise exceptions.MutualTLSChannelError("Invalid ECP configuration.")

google/auth/transport/requests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ def __init__(self, enterprise_cert_file_path):
274274

275275
self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path)
276276
self.signer.load_libraries()
277-
self.signer.set_up_custom_key()
278277

279278
poolmanager = create_urllib3_context()
280279
poolmanager.load_verify_locations(cafile=certifi.where())
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"libs": {
3+
"ecp_client": "/path/to/signer/lib",
4+
"ecp_provider": "/path/to/provider/lib"
5+
}
6+
}

tests/transport/test__custom_tls_signer.py

Lines changed: 67 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
import base64
1615
import ctypes
1716
import os
@@ -30,11 +29,19 @@
3029
ENTERPRISE_CERT_FILE = os.path.join(
3130
os.path.dirname(__file__), "../data/enterprise_cert_valid.json"
3231
)
32+
ENTERPRISE_CERT_FILE_PROVIDER = os.path.join(
33+
os.path.dirname(__file__), "../data/enterprise_cert_valid_provider.json"
34+
)
3335
INVALID_ENTERPRISE_CERT_FILE = os.path.join(
3436
os.path.dirname(__file__), "../data/enterprise_cert_invalid.json"
3537
)
3638

3739

40+
def test_load_provider_lib():
41+
with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()):
42+
_custom_tls_signer.load_provider_lib("/path/to/provider/lib")
43+
44+
3845
def test_load_offload_lib():
3946
with mock.patch("ctypes.CDLL", return_value=mock.MagicMock()):
4047
lib = _custom_tls_signer.load_offload_lib("/path/to/offload/lib")
@@ -173,62 +180,81 @@ def test_custom_tls_signer():
173180
) as load_offload_lib:
174181
load_offload_lib.return_value = offload_lib
175182
load_signer_lib.return_value = signer_lib
176-
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
177-
signer_object.load_libraries()
178-
assert signer_object._cert is None
183+
with mock.patch(
184+
"google.auth.transport._custom_tls_signer.get_cert"
185+
) as get_cert:
186+
with mock.patch(
187+
"google.auth.transport._custom_tls_signer.get_sign_callback"
188+
) as get_sign_callback:
189+
get_cert.return_value = b"mock_cert"
190+
signer_object = _custom_tls_signer.CustomTlsSigner(
191+
ENTERPRISE_CERT_FILE
192+
)
193+
signer_object.load_libraries()
194+
signer_object.attach_to_ssl_context(create_urllib3_context())
195+
get_cert.assert_called_once()
196+
get_sign_callback.assert_called_once()
197+
offload_lib.ConfigureSslContext.assert_called_once()
179198
assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE
180199
assert signer_object._offload_lib == offload_lib
181200
assert signer_object._signer_lib == signer_lib
182201
load_signer_lib.assert_called_with("/path/to/signer/lib")
183202
load_offload_lib.assert_called_with("/path/to/offload/lib")
184203

185-
# Test set_up_custom_key and set_up_ssl_context methods
186-
with mock.patch("google.auth.transport._custom_tls_signer.get_cert") as get_cert:
187-
with mock.patch(
188-
"google.auth.transport._custom_tls_signer.get_sign_callback"
189-
) as get_sign_callback:
190-
get_cert.return_value = b"mock_cert"
191-
signer_object.set_up_custom_key()
192-
signer_object.attach_to_ssl_context(create_urllib3_context())
193-
get_cert.assert_called_once()
194-
get_sign_callback.assert_called_once()
195-
offload_lib.ConfigureSslContext.assert_called_once()
196204

205+
def test_custom_tls_signer_provider():
206+
provider_lib = mock.MagicMock()
197207

198-
def test_custom_tls_signer_failed_to_load_libraries():
199208
# Test load_libraries method
209+
with mock.patch(
210+
"google.auth.transport._custom_tls_signer.load_provider_lib"
211+
) as load_provider_lib:
212+
load_provider_lib.return_value = provider_lib
213+
signer_object = _custom_tls_signer.CustomTlsSigner(
214+
ENTERPRISE_CERT_FILE_PROVIDER
215+
)
216+
signer_object.load_libraries()
217+
signer_object.attach_to_ssl_context(mock.MagicMock())
218+
219+
assert signer_object._enterprise_cert_file_path == ENTERPRISE_CERT_FILE_PROVIDER
220+
assert signer_object._provider_lib == provider_lib
221+
load_provider_lib.assert_called_with("/path/to/provider/lib")
222+
223+
224+
def test_custom_tls_signer_failed_to_load_libraries():
200225
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
201226
signer_object = _custom_tls_signer.CustomTlsSigner(INVALID_ENTERPRISE_CERT_FILE)
202227
signer_object.load_libraries()
203228
assert excinfo.match("enterprise cert file is invalid")
204229

205230

206-
def test_custom_tls_signer_fail_to_offload():
207-
offload_lib = mock.MagicMock()
208-
signer_lib = mock.MagicMock()
231+
def test_custom_tls_signer_failed_to_attach():
232+
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
233+
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
234+
signer_object._offload_lib = mock.MagicMock()
235+
signer_object._signer_lib = mock.MagicMock()
236+
signer_object._sign_callback = mock.MagicMock()
237+
signer_object._cert = b"mock cert"
238+
signer_object._offload_lib.ConfigureSslContext.return_value = False
239+
signer_object.attach_to_ssl_context(mock.MagicMock())
240+
assert excinfo.match("failed to configure ECP Offload SSL context")
209241

210-
with mock.patch(
211-
"google.auth.transport._custom_tls_signer.load_signer_lib"
212-
) as load_signer_lib:
213-
with mock.patch(
214-
"google.auth.transport._custom_tls_signer.load_offload_lib"
215-
) as load_offload_lib:
216-
load_offload_lib.return_value = offload_lib
217-
load_signer_lib.return_value = signer_lib
218-
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
219-
signer_object.load_libraries()
220242

221-
# set the return value to be 0 which indicts offload fails
222-
offload_lib.ConfigureSslContext.return_value = 0
243+
def test_custom_tls_signer_failed_to_attach_provider():
244+
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
245+
signer_object = _custom_tls_signer.CustomTlsSigner(
246+
ENTERPRISE_CERT_FILE_PROVIDER
247+
)
248+
signer_object._provider_lib = mock.MagicMock()
249+
signer_object._provider_lib.ECP_attach_to_ctx.return_value = False
250+
signer_object.attach_to_ssl_context(mock.MagicMock())
251+
assert excinfo.match("failed to configure ECP Provider SSL context")
223252

253+
254+
def test_custom_tls_signer_failed_to_attach_no_libs():
224255
with pytest.raises(exceptions.MutualTLSChannelError) as excinfo:
225-
with mock.patch(
226-
"google.auth.transport._custom_tls_signer.get_cert"
227-
) as get_cert:
228-
with mock.patch(
229-
"google.auth.transport._custom_tls_signer.get_sign_callback"
230-
):
231-
get_cert.return_value = b"mock_cert"
232-
signer_object.set_up_custom_key()
233-
signer_object.attach_to_ssl_context(create_urllib3_context())
234-
assert excinfo.match("failed to configure SSL context")
256+
signer_object = _custom_tls_signer.CustomTlsSigner(ENTERPRISE_CERT_FILE)
257+
signer_object._offload_lib = None
258+
signer_object._signer_lib = None
259+
signer_object.attach_to_ssl_context(mock.MagicMock())
260+
assert excinfo.match("Invalid ECP configuration.")

tests/transport/test_requests.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -544,17 +544,13 @@ class TestMutualTlsOffloadAdapter(object):
544544
@mock.patch.object(
545545
google.auth.transport._custom_tls_signer.CustomTlsSigner, "load_libraries"
546546
)
547-
@mock.patch.object(
548-
google.auth.transport._custom_tls_signer.CustomTlsSigner, "set_up_custom_key"
549-
)
550547
@mock.patch.object(
551548
google.auth.transport._custom_tls_signer.CustomTlsSigner,
552549
"attach_to_ssl_context",
553550
)
554551
def test_success(
555552
self,
556553
mock_attach_to_ssl_context,
557-
mock_set_up_custom_key,
558554
mock_load_libraries,
559555
mock_proxy_manager_for,
560556
mock_init_poolmanager,
@@ -565,7 +561,6 @@ def test_success(
565561
)
566562

567563
mock_load_libraries.assert_called_once()
568-
mock_set_up_custom_key.assert_called_once()
569564
assert mock_attach_to_ssl_context.call_count == 2
570565

571566
adapter.init_poolmanager()

0 commit comments

Comments
 (0)