@@ -314,12 +314,11 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]:
314314 # detect Azure AD Tenant ID if it's not specified directly
315315 token_endpoint = cfg .oidc_endpoints .token_endpoint
316316 cfg .azure_tenant_id = token_endpoint .replace (aad_endpoint , '' ).split ('/' )[0 ]
317- inner = ClientCredentials (
318- client_id = cfg .azure_client_id ,
319- client_secret = "" , # we have no (rotatable) secrets in OIDC flow
320- token_url = f"{ aad_endpoint } { cfg .azure_tenant_id } /oauth2/token" ,
321- endpoint_params = params ,
322- use_params = True )
317+ inner = ClientCredentials (client_id = cfg .azure_client_id ,
318+ client_secret = "" , # we have no (rotatable) secrets in OIDC flow
319+ token_url = f"{ aad_endpoint } { cfg .azure_tenant_id } /oauth2/token" ,
320+ endpoint_params = params ,
321+ use_params = True )
323322
324323 def refreshed_headers () -> Dict [str , str ]:
325324 token = inner .token ()
@@ -725,11 +724,10 @@ def inner() -> Dict[str, str]:
725724# https://github.com/mlflow/mlflow/blob/1219e3ef1aac7d337a618a352cd859b336cf5c81/mlflow/legacy_databricks_cli/configure/provider.py#L332
726725class ModelServingAuthProvider ():
727726 USER_CREDENTIALS = "user_credentials"
728- EMBEDDED_CREDENTIALS = "embedded_credentials"
729727
730728 _MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH = "/var/credentials-secret/model-dependencies-oauth-token"
731729
732- def __init__ (self , credential_type ):
730+ def __init__ (self , credential_type : Optional [ str ] ):
733731 self .expiry_time = - 1
734732 self .current_token = None
735733 self .refresh_duration = 300 # 300 Seconds
@@ -746,7 +744,7 @@ def should_fetch_model_serving_environment_oauth() -> bool:
746744 return (is_in_model_serving_env == "true"
747745 and os .path .isfile (ModelServingAuthProvider ._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH ))
748746
749- def get_model_dependency_oauth_token (self , should_retry = True ) -> str :
747+ def _get_model_dependency_oauth_token (self , should_retry = True ) -> str :
750748 # Use Cached value if it is valid
751749 if self .current_token is not None and self .expiry_time > time .time ():
752750 return self .current_token
@@ -762,14 +760,14 @@ def get_model_dependency_oauth_token(self, should_retry=True) -> str:
762760 logger .warning ("Unable to read oauth token on first attmept in Model Serving Environment" ,
763761 exc_info = e )
764762 time .sleep (0.5 )
765- return self .get_model_dependency_oauth_token (should_retry = False )
763+ return self ._get_model_dependency_oauth_token (should_retry = False )
766764 else :
767765 raise RuntimeError (
768766 "Unable to read OAuth credentials from the file mounted in Databricks Model Serving"
769767 ) from e
770768 return self .current_token
771769
772- def get_invokers_token (self ):
770+ def _get_invokers_token (self ):
773771 current_thread = threading .current_thread ()
774772 thread_data = current_thread .__dict__
775773 invokers_token = None
@@ -788,18 +786,16 @@ def get_databricks_host_token(self) -> Optional[Tuple[str, str]]:
788786 # read from DB_MODEL_SERVING_HOST_ENV_VAR if available otherwise MODEL_SERVING_HOST_ENV_VAR
789787 host = os .environ .get ("DATABRICKS_MODEL_SERVING_HOST_URL" ) or os .environ .get (
790788 "DB_MODEL_SERVING_HOST_URL" )
791- token = self .get_model_dependency_oauth_token (
792- ) if self .credential_type == ModelServingAuthProvider .EMBEDDED_CREDENTIALS else self .get_invokers_token (
793- )
794789
795- return (host , token )
790+ if self .credential_type == ModelServingAuthProvider .USER_CREDENTIALS :
791+ return (host , self ._get_invokers_token ())
792+ else :
793+ return (host , self ._get_model_dependency_oauth_token ())
796794
797795
798- def model_serving_auth_func (cfg : 'Config' , credential_type ) -> Optional [CredentialsProvider ]:
796+ def model_serving_auth_visitor (cfg : 'Config' ,
797+ credential_type : Optional [str ] = None ) -> Optional [CredentialsProvider ]:
799798 try :
800- if not ModelServingAuthProvider .should_fetch_model_serving_environment_oauth ():
801- logger .debug ("model-serving: Not in Databricks Model Serving, skipping" )
802- return None
803799 model_serving_auth_provider = ModelServingAuthProvider (credential_type )
804800 host , token = model_serving_auth_provider .get_databricks_host_token ()
805801 if token is None :
@@ -823,7 +819,11 @@ def inner() -> Dict[str, str]:
823819
824820@credentials_strategy ('model-serving' , [])
825821def model_serving_auth (cfg : 'Config' ) -> Optional [CredentialsProvider ]:
826- return model_serving_auth_func (cfg , ModelServingAuthProvider .EMBEDDED_CREDENTIALS )
822+ if not ModelServingAuthProvider .should_fetch_model_serving_environment_oauth ():
823+ logger .debug ("model-serving: Not in Databricks Model Serving, skipping" )
824+ return None
825+
826+ return model_serving_auth_visitor (cfg )
827827
828828
829829class DefaultCredentials :
@@ -870,37 +870,25 @@ def __call__(self, cfg: 'Config') -> CredentialsProvider:
870870 )
871871
872872
873- class AgentCredentials (CredentialsStrategy ):
873+ class ModelServingUserCredentials (CredentialsStrategy ):
874874
875- def __init__ (self , credential_type ):
876- self .credential_type = credential_type
875+ def __init__ (self ):
876+ self .credential_type = ModelServingAuthProvider . USER_CREDENTIALS
877877 self .default_credentials = DefaultCredentials ()
878878
879879 def auth_type (self ):
880880 if ModelServingAuthProvider .should_fetch_model_serving_environment_oauth ():
881- return "agent_ " + self .credential_type
881+ return "model_serving_ " + self .credential_type
882882 else :
883883 return self .default_credentials .auth_type ()
884884
885885 def __call__ (self , cfg : 'Config' ) -> CredentialsProvider :
886886 if ModelServingAuthProvider .should_fetch_model_serving_environment_oauth ():
887- header_factory = model_serving_auth_func (cfg , self .credential_type )
887+ header_factory = model_serving_auth_visitor (cfg , self .credential_type )
888888 if not header_factory :
889889 raise ValueError (
890890 f"Unable to authenticate using { self .credential_type } in Databricks Model Serving Environment"
891891 )
892892 return header_factory
893893 else :
894894 return self .default_credentials (cfg )
895-
896-
897- class AgentUserCredentials (AgentCredentials ):
898-
899- def __init__ (self ):
900- super ().__init__ (ModelServingAuthProvider .USER_CREDENTIALS )
901-
902-
903- class AgentEmbeddedCredentials (AgentCredentials ):
904-
905- def __init__ (self ):
906- super ().__init__ (ModelServingAuthProvider .EMBEDDED_CREDENTIALS )
0 commit comments