@@ -712,12 +712,12 @@ def test__get_metadata_ip_root_no_mtls():
712712 assert _metadata ._get_metadata_ip_root (use_mtls = False ) == "http://169.254.169.254"
713713
714714
715- @mock .patch ("google.auth.compute_engine._mtls.create_session " )
716- def test__prepare_request_for_mds_mtls (mock_create_session ):
717- request = mock .Mock ( )
718- new_request = _metadata ._prepare_request_for_mds (request , use_mtls = True )
719- mock_create_session .assert_called_once ()
720- assert isinstance ( new_request , google_auth_requests . Request )
715+ @mock .patch ("google.auth.compute_engine._mtls.MdsMtlsAdapter " )
716+ def test__prepare_request_for_mds_mtls (mock_mds_mtls_adapter ):
717+ request = google_auth_requests . Request ( mock .create_autospec ( requests . Session ) )
718+ _metadata ._prepare_request_for_mds (request , use_mtls = True )
719+ mock_mds_mtls_adapter .assert_called_once ()
720+ assert request . session . mount . call_count == len ( _metadata . _GCE_DEFAULT_MDS_HOSTS )
721721
722722
723723def test__prepare_request_for_mds_no_mtls ():
@@ -726,53 +726,100 @@ def test__prepare_request_for_mds_no_mtls():
726726 assert new_request is request
727727
728728
729- @mock .patch ("google.auth.compute_engine._mtls.should_use_mds_mtls" , return_value = True )
730- @mock .patch ("google.auth.compute_engine._mtls.create_session" )
731729@mock .patch ("google.auth.metrics.mds_ping" , return_value = MDS_PING_METRICS_HEADER_VALUE )
730+ @mock .patch ("google.auth.compute_engine._mtls.MdsMtlsAdapter" )
731+ @mock .patch ("google.auth.compute_engine._mtls.should_use_mds_mtls" , return_value = True )
732+ @mock .patch ("google.auth.transport.requests.Request" )
732733def test_ping_mtls (
733- mock_metrics_header_value , mock_create_session , mock_should_use_mtls
734+ mock_request , mock_should_use_mtls , mock_mds_mtls_adapter , mock_metrics_header_value
734735):
735- response = mock .create_autospec (requests .Response , instance = True )
736- response .status_code = http_client .OK
736+ response = mock .create_autospec (transport .Response , instance = True )
737+ response .status = http_client .OK
737738 response .headers = _metadata ._METADATA_HEADERS
738- mock_session = mock .Mock ()
739- mock_session .request .return_value = response
740- mock_create_session .return_value = mock_session
739+ mock_request .return_value = response
741740
742- initial_request = mock .Mock ()
743- assert _metadata .ping (initial_request )
741+ assert _metadata .ping (mock_request )
744742
745743 mock_should_use_mtls .assert_called_once ()
746- mock_create_session .assert_called_once ()
747- mock_session . request .assert_called_once_with (
748- "GET " ,
749- "https://169.254.169.254 " ,
744+ mock_mds_mtls_adapter .assert_called_once ()
745+ mock_request .assert_called_once_with (
746+ url = "https://169.254.169.254 " ,
747+ method = "GET " ,
750748 headers = MDS_PING_REQUEST_HEADER ,
751749 timeout = _metadata ._METADATA_DEFAULT_TIMEOUT ,
752- data = None ,
753750 )
754751
755752
753+ @mock .patch ("google.auth.compute_engine._mtls.MdsMtlsAdapter" )
756754@mock .patch ("google.auth.compute_engine._mtls.should_use_mds_mtls" , return_value = True )
757- @mock .patch ("google.auth.compute_engine._mtls.create_session " )
758- def test_get_mtls (mock_create_session , mock_should_use_mtls ):
759- response = mock .create_autospec (requests .Response , instance = True )
760- response .status_code = http_client .OK
761- response .content = _helpers .to_bytes ("{}" )
755+ @mock .patch ("google.auth.transport.requests.Request " )
756+ def test_get_mtls (mock_request , mock_should_use_mtls , mock_mds_mtls_adapter ):
757+ response = mock .create_autospec (transport .Response , instance = True )
758+ response .status = http_client .OK
759+ response .data = _helpers .to_bytes ("{}" )
762760 response .headers = {"content-type" : "application/json" }
763- mock_session = mock .Mock ()
764- mock_session .request .return_value = response
765- mock_create_session .return_value = mock_session
761+ mock_request .return_value = response
766762
767- initial_request = mock .Mock ()
768- _metadata .get (initial_request , "some/path" )
763+ _metadata .get (mock_request , "some/path" )
769764
770765 mock_should_use_mtls .assert_called_once ()
771- mock_create_session .assert_called_once ()
772- mock_session .request .assert_called_once_with (
773- "GET" ,
774- "https://metadata.google.internal/computeMetadata/v1/some/path" ,
775- data = None ,
766+ mock_mds_mtls_adapter .assert_called_once ()
767+ mock_request .assert_called_once_with (
768+ url = "https://metadata.google.internal/computeMetadata/v1/some/path" ,
769+ method = "GET" ,
776770 headers = _metadata ._METADATA_HEADERS ,
777771 timeout = _metadata ._METADATA_DEFAULT_TIMEOUT ,
778772 )
773+
774+
775+ @pytest .mark .parametrize (
776+ "mds_mode, metadata_host, expect_exception" ,
777+ [
778+ (_metadata ._mtls .MdsMtlsMode .STRICT , _metadata ._GCE_DEFAULT_HOST , False ),
779+ (_metadata ._mtls .MdsMtlsMode .STRICT , "custom.host" , True ),
780+ (_metadata ._mtls .MdsMtlsMode .NONE , "custom.host" , False ),
781+ (_metadata ._mtls .MdsMtlsMode .DEFAULT , _metadata ._GCE_DEFAULT_HOST , False ),
782+ ],
783+ )
784+ @mock .patch ("google.auth.compute_engine._mtls._parse_mds_mode" )
785+ def test_validate_gce_mds_configured_environment (
786+ mock_parse_mds_mode , mds_mode , metadata_host , expect_exception
787+ ):
788+ mock_parse_mds_mode .return_value = mds_mode
789+ with mock .patch (
790+ "google.auth.compute_engine._metadata._GCE_METADATA_HOST" , new = metadata_host
791+ ):
792+ if expect_exception :
793+ with pytest .raises (exceptions .MutualTLSChannelError ):
794+ _metadata ._validate_gce_mds_configured_environment ()
795+ else :
796+ _metadata ._validate_gce_mds_configured_environment ()
797+ mock_parse_mds_mode .assert_called_once ()
798+
799+
800+ @mock .patch ("google.auth.compute_engine._mtls.MdsMtlsAdapter" )
801+ def test__prepare_request_for_mds_mtls_session_exists (mock_mds_mtls_adapter ):
802+ mock_session = mock .create_autospec (requests .Session )
803+ request = google_auth_requests .Request (mock_session )
804+ new_request = _metadata ._prepare_request_for_mds (request , use_mtls = True )
805+
806+ mock_mds_mtls_adapter .assert_called_once ()
807+ assert mock_session .mount .call_count == len (_metadata ._GCE_DEFAULT_MDS_HOSTS )
808+ assert new_request is request
809+
810+
811+ @mock .patch ("google.auth.compute_engine._mtls.MdsMtlsAdapter" )
812+ def test__prepare_request_for_mds_mtls_no_session (mock_mds_mtls_adapter ):
813+ request = google_auth_requests .Request (None )
814+ # Explicitly set session to None to avoid a session being created in the Request constructor.
815+ request .session = None
816+
817+ with mock .patch ("requests.Session" ) as mock_session_class :
818+ new_request = _metadata ._prepare_request_for_mds (request , use_mtls = True )
819+
820+ mock_session_class .assert_called_once ()
821+ mock_mds_mtls_adapter .assert_called_once ()
822+ assert new_request .session .mount .call_count == len (
823+ _metadata ._GCE_DEFAULT_MDS_HOSTS
824+ )
825+ assert new_request is request
0 commit comments