diff --git a/src/flyte/cli/_get.py b/src/flyte/cli/_get.py index 8a5bc1a90..627ac778e 100644 --- a/src/flyte/cli/_get.py +++ b/src/flyte/cli/_get.py @@ -1,4 +1,5 @@ import asyncio +import time from typing import Tuple, Union, get_args import rich_click as click @@ -43,7 +44,12 @@ def project(cfg: common.CLIConfig, name: str | None = None): console = common.get_console() if name: - console.print(pretty_repr(remote.Project.get(name))) + for i in range(3): + start = time.time() + console.print(pretty_repr(remote.Project.get(name))) + end = time.time() + print(f"----- Time to fetch project {name}, attempt {i}: {end - start:.2f} seconds") + else: console.print(common.format("Projects", remote.Project.listall(), cfg.output_format)) diff --git a/src/flyte/remote/_client/auth/_authenticators/base.py b/src/flyte/remote/_client/auth/_authenticators/base.py index b50f70d79..139e762ee 100644 --- a/src/flyte/remote/_client/auth/_authenticators/base.py +++ b/src/flyte/remote/_client/auth/_authenticators/base.py @@ -1,6 +1,7 @@ import asyncio import dataclasses import ssl +import time import typing from abc import abstractmethod from http import HTTPStatus @@ -131,18 +132,23 @@ async def get_grpc_call_auth_metadata(self) -> typing.Optional[GrpcAuthMetadata] :return: A tuple of (header_key, header_value) or None if no credentials are available """ - creds = self.get_credentials() - if creds: - header_key = self._default_header_key - if self._resolved_config is not None: - # We only resolve the config during authentication flow, to avoid unnecessary network calls - # and usually the header_key is consistent. - header_key = self._resolved_config.header_key - return GrpcAuthMetadata( - creds_id=creds.id, - pairs=Metadata((header_key, f"Bearer {creds.access_token}")), - ) - return None + start = time.time() + try: + creds = self.get_credentials() + if creds: + header_key = self._default_header_key + if self._resolved_config is not None: + # We only resolve the config during authentication flow, to avoid unnecessary network calls + # and usually the header_key is consistent. + header_key = self._resolved_config.header_key + return GrpcAuthMetadata( + creds_id=creds.id, + pairs=Metadata((header_key, f"Bearer {creds.access_token}")), + ) + return None + finally: + end = time.time() + print(f"----- Time to get gRPC auth metadata: {end - start:.2f} seconds") async def refresh_credentials(self, creds_id: str | None = None): """ @@ -168,23 +174,28 @@ async def refresh_credentials(self, creds_id: str | None = None): # Credentials have been refreshed by another thread/coroutine since caller read them return - # Use the async lock to ensure coroutine safety - async with self._async_lock: - # Double-check pattern to avoid unnecessary work - if creds_id and creds_id != self._creds_id: - # Another thread/coroutine refreshed credentials while we were waiting for the lock - return - - # Perform the actual credential refresh - try: - self._creds = await self._do_refresh_credentials() - KeyringStore.store(self._creds) - except Exception: - KeyringStore.delete(self._endpoint) - raise - - # Update the timestamp to indicate credentials have been refreshed - self._creds_id = self._creds.id + start = time.time() + try: + # Use the async lock to ensure coroutine safety + async with self._async_lock: + # Double-check pattern to avoid unnecessary work + if creds_id and creds_id != self._creds_id: + # Another thread/coroutine refreshed credentials while we were waiting for the lock + return + + # Perform the actual credential refresh + try: + self._creds = await self._do_refresh_credentials() + KeyringStore.store(self._creds) + except Exception: + KeyringStore.delete(self._endpoint) + raise + + # Update the timestamp to indicate credentials have been refreshed + self._creds_id = self._creds.id + finally: + end = time.time() + print(f"----- Time to refresh credentials: {end - start:.2f} seconds") @abstractmethod async def _do_refresh_credentials(self) -> Credentials: diff --git a/src/flyte/remote/_client/auth/_client_config.py b/src/flyte/remote/_client/auth/_client_config.py index eb1150a39..edce1dafe 100644 --- a/src/flyte/remote/_client/auth/_client_config.py +++ b/src/flyte/remote/_client/auth/_client_config.py @@ -1,4 +1,5 @@ import asyncio +import time import typing from abc import abstractmethod @@ -69,17 +70,22 @@ async def get_client_config(self) -> ClientConfig: """ Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available """ - metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel) - oauth2_metadata_task = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest()) - public_client_config_task = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest()) - oauth2_metadata, public_client_config = await asyncio.gather(oauth2_metadata_task, public_client_config_task) - return ClientConfig( - token_endpoint=oauth2_metadata.token_endpoint, - authorization_endpoint=oauth2_metadata.authorization_endpoint, - redirect_uri=public_client_config.redirect_uri, - client_id=public_client_config.client_id, - scopes=public_client_config.scopes, - header_key=public_client_config.authorization_metadata_key, - device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, - audience=public_client_config.audience, - ) + start = time.time() + try: + metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel) + oauth2_metadata_task = metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest()) + public_client_config_task = metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest()) + oauth2_metadata, public_client_config = await asyncio.gather(oauth2_metadata_task, public_client_config_task) + return ClientConfig( + token_endpoint=oauth2_metadata.token_endpoint, + authorization_endpoint=oauth2_metadata.authorization_endpoint, + redirect_uri=public_client_config.redirect_uri, + client_id=public_client_config.client_id, + scopes=public_client_config.scopes, + header_key=public_client_config.authorization_metadata_key, + device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, + audience=public_client_config.audience, + ) + finally: + end = time.time() + print(f"----- Time to get client config: {end - start:.2f} seconds", flush=True) diff --git a/src/flyte/remote/_client/controlplane.py b/src/flyte/remote/_client/controlplane.py index ccce52b83..500c39edd 100644 --- a/src/flyte/remote/_client/controlplane.py +++ b/src/flyte/remote/_client/controlplane.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import time # Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings # before importing grpc @@ -57,9 +58,14 @@ def __init__( @classmethod async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet: - return cls( - await create_channel(endpoint, None, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs - ) + start = time.time() + try: + return cls( + await create_channel(endpoint, None, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs + ) + finally: + end = time.time() + print(f"----- Time to create channel to {endpoint}: {end - start:.2f} seconds") @classmethod async def for_api_key(cls, api_key: str, *, insecure: bool = False, **kwargs) -> ClientSet: diff --git a/src/flyte/remote/_project.py b/src/flyte/remote/_project.py index d5fedc36c..e20375554 100644 --- a/src/flyte/remote/_project.py +++ b/src/flyte/remote/_project.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from dataclasses import dataclass from typing import AsyncIterator, Iterator, Literal, Tuple, Union @@ -31,14 +32,20 @@ async def get(cls, name: str, org: str | None = None) -> Project: :param org: The organization of the project (if applicable). """ ensure_client() - service = get_client().project_domain_service # type: ignore - resp = await service.GetProject( - project_pb2.ProjectGetRequest( - id=name, - # org=org, + start = time.time() + try: + service = get_client().project_domain_service # type: ignore + resp = await service.GetProject( + project_pb2.ProjectGetRequest( + id=name, + # org=org, + ) ) - ) - return cls(resp) + return cls(resp) + finally: + end = time.time() + print(f"----- Time to fetch project {name}: {end - start:.2f} seconds") + @syncify @classmethod @@ -61,21 +68,26 @@ async def listall( key=sort_by[0], direction=common_pb2.Sort.ASCENDING if sort_by[1] == "asc" else common_pb2.Sort.DESCENDING ) # org = get_common_config().org - while True: - resp = await get_client().project_domain_service.ListProjects( # type: ignore - project_pb2.ProjectListRequest( - limit=100, - token=token, - filters=filters, - sort_by=sort_pb2, - # org=org, + start = time.time() + try: + while True: + resp = await get_client().project_domain_service.ListProjects( # type: ignore + project_pb2.ProjectListRequest( + limit=100, + token=token, + filters=filters, + sort_by=sort_pb2, + # org=org, + ) ) - ) - token = resp.token - for p in resp.projects: - yield cls(p) - if not token: - break + token = resp.token + for p in resp.projects: + yield cls(p) + if not token: + break + finally: + end = time.time() + print(f"----- Time to list projects: {end - start:.2f} seconds", flush=True) def __rich_repr__(self) -> rich.repr.Result: yield "name", self.pb2.name