From 618bc9be2536280beae36299383f934a24329644 Mon Sep 17 00:00:00 2001
From: Tejas Kochar
Date: Wed, 5 Nov 2025 19:17:54 +0000
Subject: [PATCH] add mypy and suppress errors for existing violations

---
 .github/workflows/push.yml                  |  10 ++
 Makefile                                    |   3 +
 databricks/sdk/_base_client.py              |  22 +--
 databricks/sdk/_widgets/__init__.py         |   4 +-
 databricks/sdk/_widgets/ipywidgets_utils.py |   4 +-
 databricks/sdk/azure.py                     |   2 +-
 databricks/sdk/casing.py                    |   2 +-
 databricks/sdk/config.py                    | 104 ++++++------
 databricks/sdk/core.py                      |   4 +-
 databricks/sdk/credentials_provider.py      | 104 ++++++------
 databricks/sdk/data_plane.py                |   2 +-
 databricks/sdk/dbutils.py                   |  42 ++---
 databricks/sdk/errors/base.py               |   2 +-
 databricks/sdk/errors/customizer.py         |   2 +-
 databricks/sdk/errors/deserializer.py       |   2 +-
 databricks/sdk/errors/mapper.py             |   2 +-
 databricks/sdk/errors/parser.py             |   2 +-
 databricks/sdk/errors/private_link.py       |   2 +-
 databricks/sdk/logger/round_trip_logger.py  |  14 +-
 databricks/sdk/mixins/compute.py            |  74 ++++-----
 databricks/sdk/mixins/files.py              | 168 ++++++++++----------
 databricks/sdk/mixins/files_utils.py        |  16 +-
 databricks/sdk/mixins/jobs.py               |  30 ++--
 databricks/sdk/mixins/open_ai_client.py     |  18 +--
 databricks/sdk/mixins/sharing.py            |   2 +-
 databricks/sdk/mixins/workspace.py          |   2 +-
 databricks/sdk/oauth.py                     |  56 +++----
 databricks/sdk/oidc.py                      |   2 +-
 databricks/sdk/oidc_token_supplier.py       |   2 +-
 databricks/sdk/retries.py                   |   6 +-
 databricks/sdk/runtime/__init__.py          |  12 +-
 databricks/sdk/runtime/dbutils_stub.py      |  56 +++----
 pyproject.toml                              |  21 ++-
 tests/conftest.py                           |   6 +-
 tests/integration/conftest.py               |   2 +-
 tests/integration/test_auth.py              |   2 +-
 tests/integration/test_clusters.py          |   2 +-
 tests/integration/test_commands.py          |   2 +-
 tests/integration/test_dbconnect.py         |   4 +-
 tests/integration/test_dbutils.py           |   2 +-
 tests/integration/test_deployment.py        |   2 +-
 tests/integration/test_external_browser.py  |   2 +-
 tests/integration/test_files.py             |   2 +-
 tests/integration/test_iam.py               |   2 +-
 tests/integration/test_jobs.py              |   2 +-
 tests/test_auth_manual_tests.py             |   2 +-
 tests/test_base_client.py                   |   8 +-
 tests/test_client.py                        |   2 +-
 tests/test_compute_mixins.py                |   2 +-
 tests/test_config.py                        |   2 +-
 tests/test_core.py                          |   2 +-
 tests/test_dbfs_mixins.py                   |   4 +-
 tests/test_dbutils.py                       |   2 +-
 tests/test_errors.py                        |   8 +-
 tests/test_fieldmask.py                     |   2 +-
 tests/test_files.py                         | 106 ++++++------
 tests/test_files_utils.py                   |  74 ++++-----
 tests/test_internal.py                      |   6 +-
 tests/test_metadata_service_auth.py         |   2 +-
 tests/test_model_serving_auth.py            |   2 +-
 tests/test_oidc.py                          |  20 +--
 tests/test_oidc_token_supplier.py           |   2 +-
 tests/test_open_ai_mixin.py                 |   2 +-
 tests/test_refreshable.py                   |   6 +-
 tests/test_retries.py                       |  12 +-
 tests/test_user_agent.py                    |   2 +-
 tests/testdata/test_casing.py               |   2 +-
 67 files changed, 563 insertions(+), 531 deletions(-)

diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 71e1e5531..cc3db88ae 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -40,6 +40,16 @@ jobs:
       - name: Fail on differences
         run: git diff --exit-code
 
+  type-check:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Run mypy type checking
+        run: make dev mypy
+
   check-manifest:
     runs-on: ubuntu-latest
 
diff --git a/Makefile b/Makefile
index c147f4074..78b189cec 100644
--- a/Makefile
+++ b/Makefile
@@ -24,6 +24,9 @@ lint:
 	pycodestyle databricks
 	autoflake --check-diff --quiet --recursive databricks
 
+mypy:
+	python -m mypy databricks tests
+
 test:
 	pytest -m 'not integration and not benchmark'
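The rest of this change applies a single pattern file by file: every line that mypy flags gets a narrowly scoped "# type: ignore[error-code]" comment, so the new "make mypy" target (and the type-check job above) passes without altering runtime behavior. A minimal, hypothetical sketch of that pattern (the function below is illustrative only and does not appear in this diff):

    # Mirrors the "str = None" defaults suppressed throughout the SDK code below.
    # On this signature mypy reports: Incompatible default for argument "name"
    # (default has type "None", argument has type "str")  [assignment]
    def greet(name: str = None) -> str:  # type: ignore[assignment]
        # A later cleanup could annotate the argument as Optional[str] and drop the ignore.
        return f"Hello, {name or 'world'}!"

    print(greet())        # Hello, world!
    print(greet("mypy"))  # Hello, mypy!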
--cov=databricks --cov-report html tests diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 33dbd17c9..2de5f7333 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -7,8 +7,8 @@ from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, Optional, Type, Union) -import requests -import requests.adapters +import requests # type: ignore[import-untyped] +import requests.adapters # type: ignore[import-untyped] from . import useragent from .casing import Casing @@ -92,7 +92,7 @@ def __init__( http_adapter = requests.adapters.HTTPAdapter( pool_connections=max_connections_per_pool or 20, pool_maxsize=max_connection_pools or 20, - pool_block=pool_block, + pool_block=pool_block, # type: ignore[arg-type] ) self._session.mount("https://", http_adapter) @@ -100,8 +100,8 @@ def __init__( self._http_timeout_seconds = http_timeout_seconds or 60 self._error_parser = _Parser( - extra_error_customizers=extra_error_customizers, - debug_headers=debug_headers, + extra_error_customizers=extra_error_customizers, # type: ignore[arg-type] + debug_headers=debug_headers, # type: ignore[arg-type] ) def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: @@ -127,7 +127,7 @@ def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: # {'filter_by.user_ids': [123, 456]} # See the following for more information: # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule - def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: # type: ignore[misc] for k1, v1 in d.items(): if isinstance(v1, dict): v1 = dict(flatten_dict(v1)) @@ -281,7 +281,7 @@ def _perform( raw: bool = False, files=None, data=None, - auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, # type: ignore[assignment] ): response = self._session.request( method, @@ -305,7 +305,7 @@ def _perform( def _record_request_log(self, response: requests.Response, raw: bool = False) -> None: if not logger.isEnabledFor(logging.DEBUG): return - logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate()) + logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate()) # type: ignore[arg-type] class _RawResponse(ABC): @@ -343,7 +343,7 @@ def _open(self) -> None: if self._closed: raise ValueError("I/O operation on closed file") if not self._content: - self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False) + self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False) # type: ignore[arg-type] def __enter__(self) -> BinaryIO: self._open() @@ -372,7 +372,7 @@ def read(self, n: int = -1) -> bytes: while remaining_bytes > 0 or read_everything: if len(self._buffer) == 0: try: - self._buffer = next(self._content) + self._buffer = next(self._content) # type: ignore[arg-type] except StopIteration: break bytes_available = len(self._buffer) @@ -416,7 +416,7 @@ def __next__(self) -> bytes: return self.read(1) def __iter__(self) -> Iterator[bytes]: - return self._content + return self._content # type: ignore[return-value] def __exit__( self, diff --git a/databricks/sdk/_widgets/__init__.py b/databricks/sdk/_widgets/__init__.py index 3f9c4eefc..686ae08e2 100644 --- a/databricks/sdk/_widgets/__init__.py +++ 
b/databricks/sdk/_widgets/__init__.py @@ -38,7 +38,7 @@ def _remove_all(self): # We only use ipywidgets if we are in a notebook interactive shell otherwise we raise error, # to fallback to using default_widgets. Also, users WILL have IPython in their notebooks (jupyter), # because we DO NOT SUPPORT any other notebook backends, and hence fallback to default_widgets. - from IPython.core.getipython import get_ipython + from IPython.core.getipython import get_ipython # type: ignore[import-not-found] # Detect if we are in an interactive notebook by iterating over the mro of the current ipython instance, # to find ZMQInteractiveShell (jupyter). When used from REPL or file, this check will fail, since the @@ -79,5 +79,5 @@ def _remove_all(self): except: from .default_widgets_utils import DefaultValueOnlyWidgetUtils - widget_impl = DefaultValueOnlyWidgetUtils + widget_impl = DefaultValueOnlyWidgetUtils # type: ignore[assignment, misc] logging.debug("Using default_value_only implementation for dbutils.") diff --git a/databricks/sdk/_widgets/ipywidgets_utils.py b/databricks/sdk/_widgets/ipywidgets_utils.py index 3caff486d..cbf1f2e9c 100644 --- a/databricks/sdk/_widgets/ipywidgets_utils.py +++ b/databricks/sdk/_widgets/ipywidgets_utils.py @@ -1,7 +1,7 @@ import typing -from IPython.core.display_functions import display -from ipywidgets.widgets import (ValueWidget, Widget, widget_box, +from IPython.core.display_functions import display # type: ignore[import-not-found] +from ipywidgets.widgets import (ValueWidget, Widget, widget_box, # type: ignore[import-not-found,import-untyped] widget_selection, widget_string) from .default_widgets_utils import WidgetUtils diff --git a/databricks/sdk/azure.py b/databricks/sdk/azure.py index 9bb000d76..a66fc126a 100644 --- a/databricks/sdk/azure.py +++ b/databricks/sdk/azure.py @@ -4,7 +4,7 @@ from .service.provisioning import Workspace -def add_workspace_id_header(cfg: "Config", headers: Dict[str, str]): +def add_workspace_id_header(cfg: "Config", headers: Dict[str, str]): # type: ignore[name-defined] if cfg.azure_workspace_resource_id: headers["X-Databricks-Azure-Workspace-Resource-Id"] = cfg.azure_workspace_resource_id diff --git a/databricks/sdk/casing.py b/databricks/sdk/casing.py index 5e0af17b4..29112a3a1 100644 --- a/databricks/sdk/casing.py +++ b/databricks/sdk/casing.py @@ -4,7 +4,7 @@ class _Name(object): def __init__(self, raw_name: str): # self._segments = [] - segment = [] + segment = [] # type: ignore[var-annotated] for ch in raw_name: if ch.isupper(): if segment: diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 879ba64ec..1ba19c805 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -8,7 +8,7 @@ import urllib.parse from typing import Dict, Iterable, List, Optional -import requests +import requests # type: ignore[import-untyped] from . 
import useragent from ._base_client import _fix_host_if_needed @@ -28,10 +28,10 @@ class ConfigAttribute: """Configuration attribute metadata and descriptor protocols.""" # name and transform are discovered from Config.__new__ - name: str = None + name: str = None # type: ignore[assignment] transform: type = str - def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): + def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): # type: ignore[assignment] self.env = env self.auth = auth self.sensitive = sensitive @@ -41,7 +41,7 @@ def __get__(self, cfg: "Config", owner): return None return cfg._inner.get(self.name, None) - def __set__(self, cfg: "Config", value: any): + def __set__(self, cfg: "Config", value: any): # type: ignore[valid-type] cfg._inner[self.name] = self.transform(value) def __repr__(self) -> str: @@ -59,58 +59,58 @@ def with_user_agent_extra(key: str, value: str): class Config: - host: str = ConfigAttribute(env="DATABRICKS_HOST") - account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") + host: str = ConfigAttribute(env="DATABRICKS_HOST") # type: ignore[assignment] + account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") # type: ignore[assignment] # PAT token. - token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) + token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) # type: ignore[assignment] # Audience for OIDC ID token source accepting an audience as a parameter. # For example, the GitHub action ID token source. - token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc") + token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc") # type: ignore[assignment] # Environment variable for OIDC token. 
- oidc_token_env: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_ENV", auth="env-oidc") - oidc_token_filepath: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_FILE", auth="file-oidc") - - username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") - password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) - - client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") - client_secret: str = ConfigAttribute(env="DATABRICKS_CLIENT_SECRET", auth="oauth", sensitive=True) - profile: str = ConfigAttribute(env="DATABRICKS_CONFIG_PROFILE") - config_file: str = ConfigAttribute(env="DATABRICKS_CONFIG_FILE") - google_service_account: str = ConfigAttribute(env="DATABRICKS_GOOGLE_SERVICE_ACCOUNT", auth="google") - google_credentials: str = ConfigAttribute(env="GOOGLE_CREDENTIALS", auth="google", sensitive=True) - azure_workspace_resource_id: str = ConfigAttribute(env="DATABRICKS_AZURE_RESOURCE_ID", auth="azure") - azure_use_msi: bool = ConfigAttribute(env="ARM_USE_MSI", auth="azure") - azure_client_secret: str = ConfigAttribute(env="ARM_CLIENT_SECRET", auth="azure", sensitive=True) - azure_client_id: str = ConfigAttribute(env="ARM_CLIENT_ID", auth="azure") - azure_tenant_id: str = ConfigAttribute(env="ARM_TENANT_ID", auth="azure") - azure_environment: str = ConfigAttribute(env="ARM_ENVIRONMENT") - databricks_cli_path: str = ConfigAttribute(env="DATABRICKS_CLI_PATH") - auth_type: str = ConfigAttribute(env="DATABRICKS_AUTH_TYPE") - cluster_id: str = ConfigAttribute(env="DATABRICKS_CLUSTER_ID") - warehouse_id: str = ConfigAttribute(env="DATABRICKS_WAREHOUSE_ID") - serverless_compute_id: str = ConfigAttribute(env="DATABRICKS_SERVERLESS_COMPUTE_ID") - skip_verify: bool = ConfigAttribute() - http_timeout_seconds: float = ConfigAttribute() - debug_truncate_bytes: int = ConfigAttribute(env="DATABRICKS_DEBUG_TRUNCATE_BYTES") - debug_headers: bool = ConfigAttribute(env="DATABRICKS_DEBUG_HEADERS") - rate_limit: int = ConfigAttribute(env="DATABRICKS_RATE_LIMIT") - retry_timeout_seconds: int = ConfigAttribute() + oidc_token_env: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_ENV", auth="env-oidc") # type: ignore[assignment] + oidc_token_filepath: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_FILE", auth="file-oidc") # type: ignore[assignment] + + username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") # type: ignore[assignment] + password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) # type: ignore[assignment] + + client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") # type: ignore[assignment] + client_secret: str = ConfigAttribute(env="DATABRICKS_CLIENT_SECRET", auth="oauth", sensitive=True) # type: ignore[assignment] + profile: str = ConfigAttribute(env="DATABRICKS_CONFIG_PROFILE") # type: ignore[assignment] + config_file: str = ConfigAttribute(env="DATABRICKS_CONFIG_FILE") # type: ignore[assignment] + google_service_account: str = ConfigAttribute(env="DATABRICKS_GOOGLE_SERVICE_ACCOUNT", auth="google") # type: ignore[assignment] + google_credentials: str = ConfigAttribute(env="GOOGLE_CREDENTIALS", auth="google", sensitive=True) # type: ignore[assignment] + azure_workspace_resource_id: str = ConfigAttribute(env="DATABRICKS_AZURE_RESOURCE_ID", auth="azure") # type: ignore[assignment] + azure_use_msi: bool = ConfigAttribute(env="ARM_USE_MSI", auth="azure") # type: ignore[assignment] + azure_client_secret: str = ConfigAttribute(env="ARM_CLIENT_SECRET", auth="azure", 
sensitive=True) # type: ignore[assignment] + azure_client_id: str = ConfigAttribute(env="ARM_CLIENT_ID", auth="azure") # type: ignore[assignment] + azure_tenant_id: str = ConfigAttribute(env="ARM_TENANT_ID", auth="azure") # type: ignore[assignment] + azure_environment: str = ConfigAttribute(env="ARM_ENVIRONMENT") # type: ignore[assignment] + databricks_cli_path: str = ConfigAttribute(env="DATABRICKS_CLI_PATH") # type: ignore[assignment] + auth_type: str = ConfigAttribute(env="DATABRICKS_AUTH_TYPE") # type: ignore[assignment] + cluster_id: str = ConfigAttribute(env="DATABRICKS_CLUSTER_ID") # type: ignore[assignment] + warehouse_id: str = ConfigAttribute(env="DATABRICKS_WAREHOUSE_ID") # type: ignore[assignment] + serverless_compute_id: str = ConfigAttribute(env="DATABRICKS_SERVERLESS_COMPUTE_ID") # type: ignore[assignment] + skip_verify: bool = ConfigAttribute() # type: ignore[assignment] + http_timeout_seconds: float = ConfigAttribute() # type: ignore[assignment] + debug_truncate_bytes: int = ConfigAttribute(env="DATABRICKS_DEBUG_TRUNCATE_BYTES") # type: ignore[assignment] + debug_headers: bool = ConfigAttribute(env="DATABRICKS_DEBUG_HEADERS") # type: ignore[assignment] + rate_limit: int = ConfigAttribute(env="DATABRICKS_RATE_LIMIT") # type: ignore[assignment] + retry_timeout_seconds: int = ConfigAttribute() # type: ignore[assignment] metadata_service_url = ConfigAttribute( env="DATABRICKS_METADATA_SERVICE_URL", auth="metadata-service", sensitive=True, ) - max_connection_pools: int = ConfigAttribute() - max_connections_per_pool: int = ConfigAttribute() + max_connection_pools: int = ConfigAttribute() # type: ignore[assignment] + max_connections_per_pool: int = ConfigAttribute() # type: ignore[assignment] databricks_environment: Optional[DatabricksEnvironment] = None - disable_async_token_refresh: bool = ConfigAttribute(env="DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH") + disable_async_token_refresh: bool = ConfigAttribute(env="DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH") # type: ignore[assignment] - disable_experimental_files_api_client: bool = ConfigAttribute( + disable_experimental_files_api_client: bool = ConfigAttribute( # type: ignore[assignment] env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT" ) @@ -217,8 +217,8 @@ def __init__( **kwargs, ): self._header_factory = None - self._inner = {} - self._user_agent_other_info = [] + self._inner = {} # type: ignore[var-annotated] + self._user_agent_other_info = [] # type: ignore[var-annotated] if credentials_strategy and credentials_provider: raise ValueError("When providing `credentials_strategy` field, `credential_provider` cannot be specified.") if credentials_provider: @@ -284,11 +284,11 @@ def parse_dsn(dsn: str) -> "Config": if attr.name not in query: continue kwargs[attr.name] = query[attr.name] - return Config(**kwargs) + return Config(**kwargs) # type: ignore[arg-type] def authenticate(self) -> Dict[str, str]: """Returns a list of fresh authentication headers""" - return self._header_factory() + return self._header_factory() # type: ignore[misc] def as_dict(self) -> dict: return self._inner @@ -314,7 +314,7 @@ def environment(self) -> DatabricksEnvironment: for environment in ALL_ENVS: if environment.cloud != Cloud.AZURE: continue - if environment.azure_environment.name != azure_env: + if environment.azure_environment.name != azure_env: # type: ignore[union-attr] continue if environment.dns_zone.startswith(".dev") or environment.dns_zone.startswith(".staging"): continue @@ -343,7 +343,7 @@ def is_account_client(self) -> bool: @property def 
arm_environment(self) -> AzureEnvironment: - return self.environment.azure_environment + return self.environment.azure_environment # type: ignore[return-value] @property def effective_azure_login_app_id(self): @@ -414,11 +414,11 @@ def debug_string(self) -> str: buf.append(f"Env: {', '.join(envs_used)}") return ". ".join(buf) - def to_dict(self) -> Dict[str, any]: + def to_dict(self) -> Dict[str, any]: # type: ignore[valid-type] return self._inner @property - def sql_http_path(self) -> Optional[str]: + def sql_http_path(self) -> Optional[str]: # type: ignore[return] """(Experimental) Return HTTP path for SQL Drivers. If `cluster_id` or `warehouse_id` are configured, return a valid HTTP Path argument @@ -465,8 +465,8 @@ def attributes(cls) -> Iterable[ConfigAttribute]: v.name = name v.transform = anno.get(name, str) attrs.append(v) - cls._attributes = attrs - return cls._attributes + cls._attributes = attrs # type: ignore[attr-defined] + return cls._attributes # type: ignore[attr-defined] def _fix_host_if_needed(self): updated_host = _fix_host_if_needed(self.host) @@ -499,7 +499,7 @@ def load_azure_tenant_id(self): self.azure_tenant_id = path_segments[1] logger.debug(f"Loaded tenant ID: {self.azure_tenant_id}") - def _set_inner_config(self, keyword_args: Dict[str, any]): + def _set_inner_config(self, keyword_args: Dict[str, any]): # type: ignore[valid-type] for attr in self.attributes(): if attr.name not in keyword_args: continue diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 92e3dbf89..63870ca23 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -5,7 +5,7 @@ from ._base_client import _BaseClient from .config import * # To preserve backwards compatibility (as these definitions were previously in this module) -from .credentials_provider import * +from .credentials_provider import * # type: ignore[no-redef] from .errors import DatabricksError, _ErrorCustomizer from .oauth import retrieve_token @@ -80,7 +80,7 @@ def do( if url is None: # Remove extra `/` from path for Files API # Once we've fixed the OpenAPI spec, we can remove this - path = re.sub("^/api/2.0/fs/files//", "/api/2.0/fs/files/", path) + path = re.sub("^/api/2.0/fs/files//", "/api/2.0/fs/files/", path) # type: ignore[arg-type] url = f"{self._cfg.host}{path}" return self._api_client.do( method=method, diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 022482370..9af4ea69d 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -15,7 +15,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import google.auth # type: ignore -import requests +import requests # type: ignore[import-untyped] from google.auth import impersonated_credentials # type: ignore from google.auth.transport.requests import Request # type: ignore from google.oauth2 import service_account # type: ignore @@ -53,7 +53,7 @@ class CredentialsStrategy(abc.ABC): def auth_type(self) -> str: ... @abc.abstractmethod - def __call__(self, cfg: "Config") -> CredentialsProvider: ... + def __call__(self, cfg: "Config") -> CredentialsProvider: ... 
# type: ignore[name-defined] class OauthCredentialsStrategy(CredentialsStrategy): @@ -63,7 +63,7 @@ class OauthCredentialsStrategy(CredentialsStrategy): def __init__( self, auth_type: str, - headers_provider: Callable[["Config"], OAuthCredentialsProvider], + headers_provider: Callable[["Config"], OAuthCredentialsProvider], # type: ignore[name-defined] ): self._headers_provider = headers_provider self._auth_type = auth_type @@ -71,10 +71,10 @@ def __init__( def auth_type(self) -> str: return self._auth_type - def __call__(self, cfg: "Config") -> OAuthCredentialsProvider: + def __call__(self, cfg: "Config") -> OAuthCredentialsProvider: # type: ignore[name-defined] return self._headers_provider(cfg) - def oauth_token(self, cfg: "Config") -> oauth.Token: + def oauth_token(self, cfg: "Config") -> oauth.Token: # type: ignore[name-defined] return self._headers_provider(cfg).oauth_token() @@ -84,17 +84,17 @@ def credentials_strategy(name: str, require: List[str]): attribute names to be present for this function to be called.""" def inner( - func: Callable[["Config"], CredentialsProvider], + func: Callable[["Config"], CredentialsProvider], # type: ignore[name-defined] ) -> CredentialsStrategy: @functools.wraps(func) - def wrapper(cfg: "Config") -> Optional[CredentialsProvider]: + def wrapper(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] for attr in require: if not getattr(cfg, attr): return None return func(cfg) - wrapper.auth_type = lambda: name - return wrapper + wrapper.auth_type = lambda: name # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] return inner @@ -110,22 +110,22 @@ def oauth_credentials_strategy(name: str, require: List[str]): """ def inner( - func: Callable[["Config"], OAuthCredentialsProvider], + func: Callable[["Config"], OAuthCredentialsProvider], # type: ignore[name-defined] ) -> OauthCredentialsStrategy: @functools.wraps(func) - def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]: + def wrapper(cfg: "Config") -> Optional[OAuthCredentialsProvider]: # type: ignore[name-defined] for attr in require: if not getattr(cfg, attr): return None return func(cfg) - return OauthCredentialsStrategy(name, wrapper) + return OauthCredentialsStrategy(name, wrapper) # type: ignore[arg-type] return inner @credentials_strategy("basic", ["host", "username", "password"]) -def basic_auth(cfg: "Config") -> CredentialsProvider: +def basic_auth(cfg: "Config") -> CredentialsProvider: # type: ignore[name-defined] """Given username and password, add base64-encoded Basic credentials""" encoded = base64.b64encode(f"{cfg.username}:{cfg.password}".encode()).decode() static_credentials = {"Authorization": f"Basic {encoded}"} @@ -137,7 +137,7 @@ def inner() -> Dict[str, str]: @credentials_strategy("pat", ["host", "token"]) -def pat_auth(cfg: "Config") -> CredentialsProvider: +def pat_auth(cfg: "Config") -> CredentialsProvider: # type: ignore[name-defined] """Adds Databricks Personal Access Token to every request""" static_credentials = {"Authorization": f"Bearer {cfg.token}"} @@ -148,7 +148,7 @@ def inner() -> Dict[str, str]: @credentials_strategy("runtime", []) -def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]: +def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] if "DATABRICKS_RUNTIME_VERSION" not in os.environ: return None @@ -177,7 +177,7 @@ def runtime_native_auth(cfg: "Config") -> Optional[CredentialsProvider]: @oauth_credentials_strategy("oauth-m2m", ["host", 
"client_id", "client_secret"]) -def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: +def oauth_service_principal(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] """Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, if /oidc/.well-known/oauth-authorization-server is available on the given host. """ @@ -205,7 +205,7 @@ def token() -> oauth.Token: @credentials_strategy("external-browser", ["host", "auth_type"]) -def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: +def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] if cfg.auth_type != "external-browser": return None @@ -246,7 +246,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: oidc_endpoints=oidc_endpoints, client_id=client_id, redirect_url=redirect_url, - client_secret=client_secret, + client_secret=client_secret, # type: ignore[arg-type] ) consent = oauth_client.initiate_consent() if not consent: @@ -257,7 +257,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: return credentials(cfg) -def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth.TokenSource]): +def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth.TokenSource]): # type: ignore[name-defined] """Resolves Azure Databricks workspace URL from ARM Resource ID""" if cfg.host: return @@ -270,7 +270,7 @@ def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth. headers={"Authorization": f"Bearer {token.access_token}"}, ) if not resp.ok: - raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}") + raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}") # type: ignore[str-bytes-safe] cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" @@ -278,7 +278,7 @@ def _ensure_host_present(cfg: "Config", token_source_for: Callable[[str], oauth. "azure-client-secret", ["is_azure", "azure_client_id", "azure_client_secret"], ) -def azure_service_principal(cfg: "Config") -> CredentialsProvider: +def azure_service_principal(cfg: "Config") -> CredentialsProvider: # type: ignore[name-defined] """Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. """ @@ -361,7 +361,7 @@ def token() -> oauth.Token: def _oidc_credentials_provider( - cfg: "Config", supplier_factory: Callable[[], Any], provider_name: str + cfg: "Config", supplier_factory: Callable[[], Any], provider_name: str # type: ignore[name-defined] ) -> Optional[CredentialsProvider]: """ Generic OIDC credentials provider that works with any OIDC token supplier. @@ -427,7 +427,7 @@ def token() -> oauth.Token: @oauth_credentials_strategy("github-oidc", ["host", "client_id"]) -def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]: +def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] """ GitHub OIDC authentication uses a Token Supplier to get a JWT Token and exchanges it for a Databricks Token. 
@@ -442,7 +442,7 @@ def github_oidc(cfg: "Config") -> Optional[CredentialsProvider]: @oauth_credentials_strategy("azure-devops-oidc", ["host", "client_id"]) -def azure_devops_oidc(cfg: "Config") -> Optional[CredentialsProvider]: +def azure_devops_oidc(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] """ Azure DevOps OIDC authentication uses a Token Supplier to get a JWT Token and exchanges it for a Databricks Token. @@ -457,7 +457,7 @@ def azure_devops_oidc(cfg: "Config") -> Optional[CredentialsProvider]: @oauth_credentials_strategy("github-oidc-azure", ["host", "azure_client_id"]) -def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: +def github_oidc_azure(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] if "ACTIONS_ID_TOKEN_REQUEST_TOKEN" not in os.environ: # not in GitHub actions return None @@ -499,10 +499,10 @@ def refreshed_headers() -> Dict[str, str]: token = inner.token() return {"Authorization": f"{token.token_type} {token.access_token}"} - def token() -> oauth.Token: + def token() -> oauth.Token: # type: ignore[no-redef] return inner.token() - return OAuthCredentialsProvider(refreshed_headers, token) + return OAuthCredentialsProvider(refreshed_headers, token) # type: ignore[arg-type] GcpScopes = [ @@ -512,7 +512,7 @@ def token() -> oauth.Token: @oauth_credentials_strategy("google-credentials", ["host", "google_credentials"]) -def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]: +def google_credentials(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] if not cfg.is_gcp: return None # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string. @@ -548,7 +548,7 @@ def refreshed_headers() -> Dict[str, str]: @oauth_credentials_strategy("google-id", ["host", "google_service_account"]) -def google_id(cfg: "Config") -> Optional[CredentialsProvider]: +def google_id(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] if not cfg.is_gcp: return None credentials, _project_id = google.auth.default() @@ -604,7 +604,7 @@ def __init__( self._expiry_field = expiry_field @staticmethod - def _parse_expiry(expiry: str) -> datetime: + def _parse_expiry(expiry: str) -> datetime: # type: ignore[return] expiry = expiry.rstrip("Z").split(".")[0] for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): try: @@ -726,7 +726,7 @@ def is_human_user(self) -> bool: return "upn" in self.token().jwt_claims() @staticmethod - def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource": + def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource": # type: ignore[name-defined] subscription = AzureCliTokenSource.get_subscription(cfg) if subscription is not None: token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id) @@ -744,7 +744,7 @@ def for_resource(cfg: "Config", resource: str) -> "AzureCliTokenSource": return token_source @staticmethod - def get_subscription(cfg: "Config") -> Optional[str]: + def get_subscription(cfg: "Config") -> Optional[str]: # type: ignore[name-defined] resource = cfg.azure_workspace_resource_id if resource is None or resource == "": return None @@ -756,7 +756,7 @@ def get_subscription(cfg: "Config") -> Optional[str]: @credentials_strategy("azure-cli", ["is_azure"]) -def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]: +def azure_cli(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] """Adds refreshed OAuth 
token granted by `az login` command to every request.""" cfg.load_azure_tenant_id() token_source = None @@ -800,7 +800,7 @@ def inner() -> Dict[str, str]: class DatabricksCliTokenSource(CliTokenSource): """Obtain the token granted by `databricks auth login` CLI command""" - def __init__(self, cfg: "Config"): + def __init__(self, cfg: "Config"): # type: ignore[name-defined] args = ["auth", "token", "--host", cfg.host] if cfg.is_account_client: args += ["--account-id", cfg.account_id] @@ -852,7 +852,7 @@ def _find_executable(name) -> str: @oauth_credentials_strategy("databricks-cli", ["host"]) -def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]: +def databricks_cli(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] try: token_source = DatabricksCliTokenSource(cfg) except FileNotFoundError as e: @@ -887,7 +887,7 @@ class MetadataServiceTokenSource(oauth.Refreshable): METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host" _metadata_service_timeout = 10 # seconds - def __init__(self, cfg: "Config"): + def __init__(self, cfg: "Config"): # type: ignore[name-defined] super().__init__() self.url = cfg.metadata_service_url self.host = cfg.host @@ -906,7 +906,7 @@ def refresh(self) -> oauth.Token: "no_proxy": "localhost,127.0.0.1" }, ) - json_resp: dict[str, Union[str, float]] = resp.json() + json_resp: dict[str, Union[str, float]] = resp.json() # type: ignore[misc] access_token = json_resp.get("access_token", None) if access_token is None: raise ValueError("Metadata Service returned empty token") @@ -916,15 +916,15 @@ def refresh(self) -> oauth.Token: if json_resp["expires_on"] in ["", None]: raise ValueError("Metadata Service returned invalid expiry") try: - expiry = datetime.fromtimestamp(json_resp["expires_on"]) + expiry = datetime.fromtimestamp(json_resp["expires_on"]) # type: ignore[arg-type] except: raise ValueError("Metadata Service returned invalid expiry") - return oauth.Token(access_token=access_token, token_type=token_type, expiry=expiry) + return oauth.Token(access_token=access_token, token_type=token_type, expiry=expiry) # type: ignore[arg-type] @credentials_strategy("metadata-service", ["host", "metadata_service_url"]) -def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]: +def metadata_service(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] """Adds refreshed token granted by Databricks Metadata Service to every request.""" token_source = MetadataServiceTokenSource(cfg) @@ -951,7 +951,7 @@ def __init__(self, credential_type: Optional[str]): self.refresh_duration = 300 # 300 Seconds self.credential_type = credential_type - def should_fetch_model_serving_environment_oauth() -> bool: + def should_fetch_model_serving_environment_oauth() -> bool: # type: ignore[misc] """ Check whether this is the model serving environment Additionally check if the oauth token file path exists @@ -975,7 +975,7 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str: with open(ModelServingAuthProvider._MODEL_DEPENDENCY_OAUTH_TOKEN_FILE_PATH) as f: oauth_dict = json.load(f) self.current_token = oauth_dict["OAUTH_TOKEN"][0]["oauthTokenValue"] - self.expiry_time = time.time() + self.refresh_duration + self.expiry_time = time.time() + self.refresh_duration # type: ignore[assignment] except Exception as e: # sleep and retry in case of any race conditions with OAuth refreshing if should_retry: @@ -989,7 +989,7 @@ def _get_model_dependency_oauth_token(self, should_retry=True) -> str: raise RuntimeError( "Unable to 
read OAuth credentials from the file mounted in Databricks Model Serving" ) from e - return self.current_token + return self.current_token # type: ignore[return-value] def _get_invokers_token(self): main_thread = threading.main_thread() @@ -1011,15 +1011,15 @@ def get_databricks_host_token(self) -> Optional[Tuple[str, str]]: host = os.environ.get("DATABRICKS_MODEL_SERVING_HOST_URL") or os.environ.get("DB_MODEL_SERVING_HOST_URL") if self.credential_type == ModelServingAuthProvider.USER_CREDENTIALS: - return (host, self._get_invokers_token()) + return (host, self._get_invokers_token()) # type: ignore[return-value] else: - return (host, self._get_model_dependency_oauth_token()) + return (host, self._get_model_dependency_oauth_token()) # type: ignore[return-value] -def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]: +def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = None) -> Optional[CredentialsProvider]: # type: ignore[name-defined] try: model_serving_auth_provider = ModelServingAuthProvider(credential_type) - host, token = model_serving_auth_provider.get_databricks_host_token() + host, token = model_serving_auth_provider.get_databricks_host_token() # type: ignore[misc] if token is None: raise ValueError( "Got malformed auth (empty token) when fetching auth implicitly available in Model Serving Environment. Please contact Databricks support" @@ -1036,14 +1036,14 @@ def model_serving_auth_visitor(cfg: "Config", credential_type: Optional[str] = N def inner() -> Dict[str, str]: # Call here again to get the refreshed token - _, token = model_serving_auth_provider.get_databricks_host_token() + _, token = model_serving_auth_provider.get_databricks_host_token() # type: ignore[misc] return {"Authorization": f"Bearer {token}"} return inner @credentials_strategy("model-serving", []) -def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]: +def model_serving_auth(cfg: "Config") -> Optional[CredentialsProvider]: # type: ignore[name-defined] if not ModelServingAuthProvider.should_fetch_model_serving_environment_oauth(): logger.debug("model-serving: Not in Databricks Model Serving, skipping") return None @@ -1079,7 +1079,7 @@ def __init__(self) -> None: def auth_type(self) -> str: return self._auth_type - def oauth_token(self, cfg: "Config") -> oauth.Token: + def oauth_token(self, cfg: "Config") -> oauth.Token: # type: ignore[name-defined, return] for provider in self._auth_providers: auth_type = provider.auth_type() if auth_type != self._auth_type: @@ -1087,7 +1087,7 @@ def oauth_token(self, cfg: "Config") -> oauth.Token: continue return provider.oauth_token(cfg) - def __call__(self, cfg: "Config") -> CredentialsProvider: + def __call__(self, cfg: "Config") -> CredentialsProvider: # type: ignore[name-defined] for provider in self._auth_providers: auth_type = provider.auth_type() if cfg.auth_type and auth_type != cfg.auth_type: @@ -1133,7 +1133,7 @@ def auth_type(self): else: return self.default_credentials.auth_type() - def __call__(self, cfg: "Config") -> CredentialsProvider: + def __call__(self, cfg: "Config") -> CredentialsProvider: # type: ignore[name-defined] if ModelServingAuthProvider.should_fetch_model_serving_environment_oauth(): header_factory = model_serving_auth_visitor(cfg, self.credential_type) if not header_factory: diff --git a/databricks/sdk/data_plane.py b/databricks/sdk/data_plane.py index aa772edcc..05c7fbaee 100644 --- a/databricks/sdk/data_plane.py +++ 
b/databricks/sdk/data_plane.py @@ -22,7 +22,7 @@ class DataPlaneTokenSource: def __init__(self, token_exchange_host: str, cpts: Callable[[], Token], disable_async: Optional[bool] = True): self._cpts = cpts self._token_exchange_host = token_exchange_host - self._token_sources = {} + self._token_sources = {} # type: ignore[var-annotated] self._disable_async = disable_async self._lock = threading.Lock() diff --git a/databricks/sdk/dbutils.py b/databricks/sdk/dbutils.py index 9c284767f..c1ef2a1ea 100644 --- a/databricks/sdk/dbutils.py +++ b/databricks/sdk/dbutils.py @@ -59,7 +59,7 @@ def ls(self, dir: str) -> List[FileInfo]: return [ FileInfo( f.path, - os.path.basename(f.path), + os.path.basename(f.path), # type: ignore[type-var] f.file_size, f.modification_time, ) @@ -91,9 +91,9 @@ def mount( self, source: str, mount_point: str, - encryption_type: str = None, - owner: str = None, - extra_configs: Dict[str, str] = None, + encryption_type: str = None, # type: ignore[assignment] + owner: str = None, # type: ignore[assignment] + extra_configs: Dict[str, str] = None, # type: ignore[assignment] ) -> bool: """Mounts the given source directory into DBFS at the given mount point""" fs = self._proxy_factory("fs") @@ -103,21 +103,21 @@ def mount( if owner: kwargs["owner"] = owner if extra_configs: - kwargs["extra_configs"] = extra_configs - return fs.mount(source, mount_point, **kwargs) + kwargs["extra_configs"] = extra_configs # type: ignore[assignment] + return fs.mount(source, mount_point, **kwargs) # type: ignore[call-arg] def unmount(self, mount_point: str) -> bool: """Deletes a DBFS mount point""" fs = self._proxy_factory("fs") - return fs.unmount(mount_point) + return fs.unmount(mount_point) # type: ignore[call-arg] def updateMount( self, source: str, mount_point: str, - encryption_type: str = None, - owner: str = None, - extra_configs: Dict[str, str] = None, + encryption_type: str = None, # type: ignore[assignment] + owner: str = None, # type: ignore[assignment] + extra_configs: Dict[str, str] = None, # type: ignore[assignment] ) -> bool: """Similar to mount(), but updates an existing mount point (if present) instead of creating a new one""" fs = self._proxy_factory("fs") @@ -127,8 +127,8 @@ def updateMount( if owner: kwargs["owner"] = owner if extra_configs: - kwargs["extra_configs"] = extra_configs - return fs.updateMount(source, mount_point, **kwargs) + kwargs["extra_configs"] = extra_configs # type: ignore[assignment] + return fs.updateMount(source, mount_point, **kwargs) # type: ignore[call-arg] def mounts(self) -> List[MountInfo]: """Displays information about what is mounted within DBFS""" @@ -186,8 +186,8 @@ def get( self, taskKey: str, key: str, - default: any = None, - debugValue: any = None, + default: any = None, # type: ignore[valid-type] + debugValue: any = None, # type: ignore[valid-type] ) -> None: """ Returns `debugValue` if present, throws an error otherwise as this implementation is always run outside of a job run @@ -198,7 +198,7 @@ def get( ) return debugValue - def set(self, key: str, value: any) -> None: + def set(self, key: str, value: any) -> None: # type: ignore[valid-type] """ Sets a task value on the current task run """ @@ -209,7 +209,7 @@ def __init__(self) -> None: class RemoteDbUtils: - def __init__(self, config: "Config" = None): + def __init__(self, config: "Config" = None): # type: ignore[assignment] # Create a shallow copy of the config to allow the use of a custom # user-agent while avoiding modifying the original config. 
self._config = Config() if not config else config.copy() @@ -254,8 +254,8 @@ def _running_command_context(self) -> compute.ContextStatusResponse: if self._ctx: return self._ctx self._clusters.ensure_cluster_is_running(self._cluster_id) - self._ctx = self._commands.create(cluster_id=self._cluster_id, language=compute.Language.PYTHON).result() - return self._ctx + self._ctx = self._commands.create(cluster_id=self._cluster_id, language=compute.Language.PYTHON).result() # type: ignore[assignment] + return self._ctx # type: ignore[return-value] def __getattr__(self, util) -> "_ProxyUtil": return _ProxyUtil( @@ -433,7 +433,7 @@ def _error_from_results(self, results: compute.Results): if results.cause: _LOG.debug(f'{self._ascii_escape_re.sub("", results.cause)}') - summary = self._tag_re.sub("", results.summary) + summary = self._tag_re.sub("", results.summary) # type: ignore[arg-type] summary = html.unescape(summary) exception_matches = self._exception_re.findall(summary) @@ -442,11 +442,11 @@ def _error_from_results(self, results: compute.Results): summary = summary.rstrip(" ") return summary - execution_error_matches = self._execution_error_re.findall(results.cause) + execution_error_matches = self._execution_error_re.findall(results.cause) # type: ignore[arg-type] if len(execution_error_matches) == 1: return "\n".join(execution_error_matches[0]) - error_message_matches = self._error_message_re.findall(results.cause) + error_message_matches = self._error_message_re.findall(results.cause) # type: ignore[arg-type] if len(error_message_matches) == 1: return error_message_matches[0] diff --git a/databricks/sdk/errors/base.py b/databricks/sdk/errors/base.py index c88ec60ff..f3adfdba6 100644 --- a/databricks/sdk/errors/base.py +++ b/databricks/sdk/errors/base.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional -import requests +import requests # type: ignore[import-untyped] from . 
import details as errdetails diff --git a/databricks/sdk/errors/customizer.py b/databricks/sdk/errors/customizer.py index 6a760b626..20ead33ff 100644 --- a/databricks/sdk/errors/customizer.py +++ b/databricks/sdk/errors/customizer.py @@ -1,7 +1,7 @@ import abc import logging -import requests +import requests # type: ignore[import-untyped] class _ErrorCustomizer(abc.ABC): diff --git a/databricks/sdk/errors/deserializer.py b/databricks/sdk/errors/deserializer.py index 5a6e0da09..cf7d9793b 100644 --- a/databricks/sdk/errors/deserializer.py +++ b/databricks/sdk/errors/deserializer.py @@ -4,7 +4,7 @@ import re from typing import Optional -import requests +import requests # type: ignore[import-untyped] class _ErrorDeserializer(abc.ABC): diff --git a/databricks/sdk/errors/mapper.py b/databricks/sdk/errors/mapper.py index c3bb5b54c..597b22a8e 100644 --- a/databricks/sdk/errors/mapper.py +++ b/databricks/sdk/errors/mapper.py @@ -1,4 +1,4 @@ -import requests +import requests # type: ignore[import-untyped] from databricks.sdk.errors import platform from databricks.sdk.errors.base import DatabricksError diff --git a/databricks/sdk/errors/parser.py b/databricks/sdk/errors/parser.py index 2fefc4e2f..65b456cf3 100644 --- a/databricks/sdk/errors/parser.py +++ b/databricks/sdk/errors/parser.py @@ -1,7 +1,7 @@ import logging from typing import List, Optional -import requests +import requests # type: ignore[import-untyped] from ..logger import RoundTrip from .base import DatabricksError diff --git a/databricks/sdk/errors/private_link.py b/databricks/sdk/errors/private_link.py index e188b59e1..0aa7889bf 100644 --- a/databricks/sdk/errors/private_link.py +++ b/databricks/sdk/errors/private_link.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from urllib import parse -import requests +import requests # type: ignore[import-untyped] from ..environments import Cloud, get_environment_for_hostname from .platform import PermissionDenied diff --git a/databricks/sdk/logger/round_trip_logger.py b/databricks/sdk/logger/round_trip_logger.py index 7ff9d55c9..fbd80f1b1 100644 --- a/databricks/sdk/logger/round_trip_logger.py +++ b/databricks/sdk/logger/round_trip_logger.py @@ -2,7 +2,7 @@ import urllib.parse from typing import Any, Dict, List -import requests +import requests # type: ignore[import-untyped] class RoundTrip: @@ -38,13 +38,13 @@ def generate(self) -> str: url = urllib.parse.urlparse(request.url) query = "" if url.query: - query = f"?{urllib.parse.unquote(url.query)}" - sb = [f"{request.method} {urllib.parse.unquote(url.path)}{query}"] + query = f"?{urllib.parse.unquote(url.query)}" # type: ignore[arg-type] + sb = [f"{request.method} {urllib.parse.unquote(url.path)}{query}"] # type: ignore[arg-type] if self._debug_headers: for k, v in request.headers.items(): sb.append(f"> * {k}: {self._only_n_bytes(v, self._debug_truncate_bytes)}") if request.body: - sb.append("> [raw stream]" if self._raw else self._redacted_dump("> ", request.body)) + sb.append("> [raw stream]" if self._raw else self._redacted_dump("> ", request.body)) # type: ignore[arg-type] sb.append(f"< {self._response.status_code} {self._response.reason}") if self._raw and self._response.headers.get("Content-Type", None) != "application/json": # Raw streams with `Transfer-Encoding: chunked` do not have `Content-Type` header @@ -55,7 +55,7 @@ def generate(self) -> str: return "\n".join(sb) @staticmethod - def _mask(m: Dict[str, any]): + def _mask(m: Dict[str, any]): # type: ignore[valid-type] for k in m: if k in { "bytes_value", @@ -67,7 +67,7 @@ def 
_mask(m: Dict[str, any]): m[k] = "**REDACTED**" @staticmethod - def _map_keys(m: Dict[str, any]) -> List[str]: + def _map_keys(m: Dict[str, any]) -> List[str]: # type: ignore[valid-type] keys = list(m.keys()) keys.sort() return keys @@ -89,7 +89,7 @@ def _recursive_marshal_dict(self, m, budget) -> dict: return out def _recursive_marshal_list(self, s, budget) -> list: - out = [] + out = [] # type: ignore[var-annotated] for i in range(len(s)): if i > 0 >= budget: out.append("... (%d additional elements)" % (len(s) - len(out))) diff --git a/databricks/sdk/mixins/compute.py b/databricks/sdk/mixins/compute.py index 164887fb3..b8c3c9b8d 100644 --- a/databricks/sdk/mixins/compute.py +++ b/databricks/sdk/mixins/compute.py @@ -64,9 +64,9 @@ def __lt__(self, other: "SemVer"): if self.patch != other.patch: return self.patch < other.patch if self.pre_release != other.pre_release: - return self.pre_release < other.pre_release + return self.pre_release < other.pre_release # type: ignore[operator] if self.build != other.build: - return self.build < other.build + return self.build < other.build # type: ignore[operator] return False @@ -82,7 +82,7 @@ def select_spark_version( genomics: bool = False, gpu: bool = False, scala: str = "2.12", - spark_version: str = None, + spark_version: str = None, # type: ignore[assignment] photon: bool = False, graviton: bool = False, ) -> str: @@ -104,22 +104,22 @@ def select_spark_version( # Logic ported from https://github.com/databricks/databricks-sdk-go/blob/main/service/compute/spark_version.go versions = [] sv = self.spark_versions() - for version in sv.versions: - if "-scala" + scala not in version.key: + for version in sv.versions: # type: ignore[union-attr] + if "-scala" + scala not in version.key: # type: ignore[operator] continue matches = ( - ("apache-spark-" not in version.key) - and (("-ml-" in version.key) == ml) - and (("-hls-" in version.key) == genomics) - and (("-gpu-" in version.key) == gpu) - and (("-photon-" in version.key) == photon) - and (("-aarch64-" in version.key) == graviton) - and (("Beta" in version.name) == beta) + ("apache-spark-" not in version.key) # type: ignore[operator] + and (("-ml-" in version.key) == ml) # type: ignore[operator] + and (("-hls-" in version.key) == genomics) # type: ignore[operator] + and (("-gpu-" in version.key) == gpu) # type: ignore[operator] + and (("-photon-" in version.key) == photon) # type: ignore[operator] + and (("-aarch64-" in version.key) == graviton) # type: ignore[operator] + and (("Beta" in version.name) == beta) # type: ignore[operator] ) if matches and long_term_support: - matches = matches and (("LTS" in version.name) or ("-esr-" in version.key)) + matches = matches and (("LTS" in version.name) or ("-esr-" in version.key)) # type: ignore[operator] if matches and spark_version: - matches = matches and ("Apache Spark " + spark_version in version.name) + matches = matches and ("Apache Spark " + spark_version in version.name) # type: ignore[operator] if matches: versions.append(version.key) if len(versions) < 1: @@ -127,17 +127,17 @@ def select_spark_version( if len(versions) > 1: if not latest: raise ValueError("spark versions query returned multiple results") - versions = sorted(versions, key=SemVer.parse, reverse=True) - return versions[0] + versions = sorted(versions, key=SemVer.parse, reverse=True) # type: ignore[arg-type] + return versions[0] # type: ignore[return-value] @staticmethod def _node_sorting_tuple(item: compute.NodeType) -> tuple: local_disks = local_disk_size_gb = local_nvme_disk = 
local_nvme_disk_size_gb = 0 if item.node_instance_type is not None: - local_disks = item.node_instance_type.local_disks - local_nvme_disk = item.node_instance_type.local_nvme_disks - local_disk_size_gb = item.node_instance_type.local_disk_size_gb - local_nvme_disk_size_gb = item.node_instance_type.local_nvme_disk_size_gb + local_disks = item.node_instance_type.local_disks # type: ignore[assignment] + local_nvme_disk = item.node_instance_type.local_nvme_disks # type: ignore[assignment] + local_disk_size_gb = item.node_instance_type.local_disk_size_gb # type: ignore[assignment] + local_nvme_disk_size_gb = item.node_instance_type.local_nvme_disk_size_gb # type: ignore[assignment] return ( item.is_deprecated, item.num_cores, @@ -167,19 +167,19 @@ def _should_node_be_skipped(nt: compute.NodeType) -> bool: def select_node_type( self, - min_memory_gb: int = None, - gb_per_core: int = None, - min_cores: int = None, - min_gpus: int = None, - local_disk: bool = None, - local_disk_min_size: int = None, - category: str = None, - photon_worker_capable: bool = None, - photon_driver_capable: bool = None, - graviton: bool = None, - is_io_cache_enabled: bool = None, - support_port_forwarding: bool = None, - fleet: str = None, + min_memory_gb: int = None, # type: ignore[assignment] + gb_per_core: int = None, # type: ignore[assignment] + min_cores: int = None, # type: ignore[assignment] + min_gpus: int = None, # type: ignore[assignment] + local_disk: bool = None, # type: ignore[assignment] + local_disk_min_size: int = None, # type: ignore[assignment] + category: str = None, # type: ignore[assignment] + photon_worker_capable: bool = None, # type: ignore[assignment] + photon_driver_capable: bool = None, # type: ignore[assignment] + graviton: bool = None, # type: ignore[assignment] + is_io_cache_enabled: bool = None, # type: ignore[assignment] + support_port_forwarding: bool = None, # type: ignore[assignment] + fleet: str = None, # type: ignore[assignment] ) -> str: """Selects smallest available node type given the conditions. 
@@ -201,7 +201,7 @@ def select_node_type( """ # Logic ported from https://github.com/databricks/databricks-sdk-go/blob/main/service/clusters/node_type.go res = self.list_node_types() - types = sorted(res.node_types, key=self._node_sorting_tuple) + types = sorted(res.node_types, key=self._node_sorting_tuple) # type: ignore[arg-type] for nt in types: if self._should_node_be_skipped(nt): continue @@ -214,12 +214,12 @@ def select_node_type( continue if min_cores is not None and nt.num_cores < min_cores: continue - if (min_gpus is not None and nt.num_gpus < min_gpus) or (min_gpus == 0 and nt.num_gpus > 0): + if (min_gpus is not None and nt.num_gpus < min_gpus) or (min_gpus == 0 and nt.num_gpus > 0): # type: ignore[operator] continue if local_disk or local_disk_min_size is not None: instance_type = nt.node_instance_type - local_disks = int(instance_type.local_disks) if instance_type.local_disks else 0 - local_nvme_disks = int(instance_type.local_nvme_disks) if instance_type.local_nvme_disks else 0 + local_disks = int(instance_type.local_disks) if instance_type.local_disks else 0 # type: ignore[union-attr] + local_nvme_disks = int(instance_type.local_nvme_disks) if instance_type.local_nvme_disks else 0 # type: ignore[union-attr] if instance_type is None or (local_disks < 1 and local_nvme_disks < 1): continue local_disk_size_gb = instance_type.local_disk_size_gb if instance_type.local_disk_size_gb else 0 diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 819e14d4b..416b6da25 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -26,9 +26,9 @@ Iterable, Optional, Type, Union) from urllib import parse -import requests -import requests.adapters -from requests import RequestException +import requests # type: ignore[import-untyped] +import requests.adapters # type: ignore[import-untyped] +from requests import RequestException # type: ignore[import-untyped] from .._base_client import _BaseClient, _RawResponse, _StreamingResponse from .._property import _cached_property @@ -51,8 +51,8 @@ class _DbfsIO(BinaryIO): MAX_CHUNK_SIZE = 1024 * 1024 - _status: files.FileInfo = None - _created: files.CreateResponse = None + _status: files.FileInfo = None # type: ignore[assignment] + _created: files.CreateResponse = None # type: ignore[assignment] _offset = 0 _closed = False @@ -76,8 +76,8 @@ def __init__( else: raise IOError(f"need to open either for reading or writing") - def __enter__(self) -> Self: - return self + def __enter__(self) -> Self: # type: ignore[type-var] + return self # type: ignore[return-value] @property def name(self) -> str: @@ -91,7 +91,7 @@ def writable(self) -> bool: """ return self._created is not None - def write(self, buffer: bytes) -> int: + def write(self, buffer: bytes) -> int: # type: ignore[override] """Write bytes to file. :return: Return the number of bytes written. 
@@ -107,14 +107,14 @@ def write(self, buffer: bytes) -> int: if len(chunk) > self.MAX_CHUNK_SIZE: chunk = chunk[: self.MAX_CHUNK_SIZE] encoded = base64.b64encode(chunk).decode() - self._api.add_block(self._created.handle, encoded) + self._api.add_block(self._created.handle, encoded) # type: ignore[arg-type] total += len(chunk) return total def close(self) -> None: """Disable all I/O operations.""" if self.writable(): - self._api.close(self._created.handle) + self._api.close(self._created.handle) # type: ignore[arg-type] self._closed = True @property @@ -132,7 +132,7 @@ def __exit__( def readable(self) -> bool: return self._status is not None - def read(self, size: int = ...) -> bytes: + def read(self, size: int = ...) -> bytes: # type: ignore[assignment] """Read at most size bytes, returned as a bytes object. :param size: If the size argument is negative, read until EOF is reached. @@ -158,12 +158,12 @@ def read(self, size: int = ...) -> bytes: # and not the EOFError as in other SDKs return b"" - raw = base64.b64decode(response.data) - self._offset += response.bytes_read + raw = base64.b64decode(response.data) # type: ignore[arg-type] + self._offset += response.bytes_read # type: ignore[operator] return raw def __iter__(self) -> Iterator[bytes]: - while self._offset < self._status.file_size: + while self._offset < self._status.file_size: # type: ignore[operator] yield self.__next__() def __next__(self) -> bytes: @@ -179,7 +179,7 @@ def flush(self) -> None: def isatty(self) -> bool: return False - def readline(self, __limit: int = ...) -> AnyStr: + def readline(self, __limit: int = ...) -> AnyStr: # type: ignore[type-var] raise NotImplementedError def readlines(self, __hint: int = ...) -> list[AnyStr]: @@ -197,7 +197,7 @@ def tell(self) -> int: def truncate(self, __size: int | None = ...) -> int: raise NotImplementedError - def writelines(self, __lines: Iterable[AnyStr]) -> None: + def writelines(self, __lines: Iterable[AnyStr]) -> None: # type: ignore[override] raise NotImplementedError def __repr__(self) -> str: @@ -215,7 +215,7 @@ def __init__( write: bool, overwrite: bool, ): - self._buffer = [] + self._buffer = [] # type: ignore[var-annotated] self._api = api self._path = path self._read = read @@ -349,7 +349,7 @@ def exists(self) -> bool: ... @abstractmethod def open(self, *, read=False, write=False, overwrite=False): ... - def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: ... + def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: ... # type: ignore[empty-body] @abstractmethod def mkdir(self): ... @@ -359,11 +359,11 @@ def delete(self, *, recursive=False): ... 
@property def name(self) -> str: - return self._path.name + return self._path.name # type: ignore[attr-defined] @property def as_string(self) -> str: - return str(self._path) + return str(self._path) # type: ignore[attr-defined] class _LocalPath(_Path): @@ -380,8 +380,8 @@ def _is_local(self) -> bool: def _is_dbfs(self) -> bool: return False - def child(self, path: str) -> Self: - return _LocalPath(str(self._path / path)) + def child(self, path: str) -> Self: # type: ignore[type-var] + return _LocalPath(str(self._path / path)) # type: ignore[return-value] def _is_dir(self) -> bool: return self._path.is_dir() @@ -451,8 +451,8 @@ def _is_local(self) -> bool: def _is_dbfs(self) -> bool: return False - def child(self, path: str) -> Self: - return _VolumesPath(self._api, str(self._path / path)) + def child(self, path: str) -> Self: # type: ignore[type-var] + return _VolumesPath(self._api, str(self._path / path)) # type: ignore[return-value] def _is_dir(self) -> bool: try: @@ -487,7 +487,7 @@ def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: path=self.as_string, is_dir=False, file_size=meta.content_length, - modification_time=meta.last_modified, + modification_time=meta.last_modified, # type: ignore[arg-type] ) return queue = deque([self]) @@ -495,7 +495,7 @@ def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: next_path = queue.popleft() for file in self._api.list_directory_contents(next_path.as_string): if recursive and file.is_directory: - queue.append(self.child(file.name)) + queue.append(self.child(file.name)) # type: ignore[arg-type] if not recursive or not file.is_directory: yield files.FileInfo( path=file.path, @@ -528,14 +528,14 @@ def _is_local(self) -> bool: def _is_dbfs(self) -> bool: return True - def child(self, path: str) -> Self: + def child(self, path: str) -> Self: # type: ignore[type-var] child = self._path / path - return _DbfsPath(self._api, str(child)) + return _DbfsPath(self._api, str(child)) # type: ignore[return-value] def _is_dir(self) -> bool: try: remote = self._api.get_status(self.as_string) - return remote.is_dir + return remote.is_dir # type: ignore[return-value] except NotFound: return False @@ -573,7 +573,7 @@ def list(self, *, recursive=False) -> Generator[files.FileInfo, None, None]: next_path = queue.popleft() for file in self._api.list(next_path.as_string): if recursive and file.is_dir: - queue.append(self.child(file.path)) + queue.append(self.child(file.path)) # type: ignore[arg-type] if not recursive or not file.is_dir: yield file @@ -677,17 +677,17 @@ def copy(self, src: str, dst: str, *, recursive=False, overwrite=False): """Copy files between DBFS and local filesystems""" src = self._path(src) dst = self._path(dst) - if src.is_local and dst.is_local: + if src.is_local and dst.is_local: # type: ignore[attr-defined] raise IOError("both destinations are on local FS") - if dst.exists() and dst.is_dir: + if dst.exists() and dst.is_dir: # type: ignore[attr-defined] # if target is a folder, make file with the same name there - dst = dst.child(src.name) - if src.is_dir: - queue = [self._path(x.path) for x in src.list(recursive=recursive) if not x.is_dir] + dst = dst.child(src.name) # type: ignore[attr-defined] + if src.is_dir: # type: ignore[attr-defined] + queue = [self._path(x.path) for x in src.list(recursive=recursive) if not x.is_dir] # type: ignore[attr-defined] else: queue = [src] for child in queue: - child_dst = dst.child(os.path.relpath(child.as_string, src.as_string)) + child_dst = 
dst.child(os.path.relpath(child.as_string, src.as_string)) # type: ignore[attr-defined] with child.open(read=True) as reader: with child_dst.open(write=True, overwrite=overwrite) as writer: shutil.copyfileobj(reader, writer, length=_DbfsIO.MAX_CHUNK_SIZE) @@ -805,7 +805,7 @@ def download( ) wrapped_response = self._wrap_stream(file_path, initial_response) - initial_response.contents._response = wrapped_response + initial_response.contents._response = wrapped_response # type: ignore[union-attr] return initial_response def download_to( @@ -890,8 +890,8 @@ def _sequential_download_to_file( if_unmodified_since_timestamp=last_modified, ) wrapped_response = self._wrap_stream(remote_path, response, 0) - response.contents._response = wrapped_response - shutil.copyfileobj(response.contents, f) + response.contents._response = wrapped_response # type: ignore[union-attr] + shutil.copyfileobj(response.contents, f) # type: ignore[arg-type] def _do_parallel_download( self, remote_path: str, destination: str, parallelism: int, download_chunk: Callable @@ -909,7 +909,7 @@ def _do_parallel_download( fd, temp_file = mkstemp() # We are preallocate the file size to the same as the remote file to avoid seeking beyond the file size. - os.truncate(temp_file, file_size) + os.truncate(temp_file, file_size) # type: ignore[arg-type] os.close(fd) try: aborted = Event() @@ -935,7 +935,7 @@ def wrapped_download_chunk(start: int, end: int, last_modified: Optional[str], t # Start the threads to download parts of the file. for i in range(part_count): start = i * part_size - end = min(start + part_size - 1, file_size - 1) + end = min(start + part_size - 1, file_size - 1) # type: ignore[operator] futures.append(executor.submit(wrapped_download_chunk, start, end, last_modified, temp_file)) # Wait for all threads to complete and check for exceptions. @@ -1017,7 +1017,7 @@ def download_chunk(additional_headers: dict[str, str]) -> BinaryIO: def _get_optimized_performance_parameters_for_upload( self, content_length: Optional[int], part_size_overwrite: Optional[int] - ) -> (int, int): + ) -> (int, int): # type: ignore[syntax] """Get optimized part size and batch size for upload based on content length and provided part size. Returns tuple of (part_size, batch_size). @@ -1234,7 +1234,7 @@ def _initiate_multipart_upload(self, ctx: _UploadContext) -> dict: """Initiate a multipart upload and return the response.""" query = {"action": "initiate-upload"} if ctx.overwrite is not None: - query["overwrite"] = ctx.overwrite + query["overwrite"] = ctx.overwrite # type: ignore[assignment] # Method _api.do() takes care of retrying and will raise an exception in case of failure. initiate_upload_response = self._api.do( @@ -1366,8 +1366,8 @@ def _parallel_upload_from_file( _LOG.info(f"Falling back to single-shot upload with Files API: {e}") # Concatenate the buffered part and the rest of the stream. 
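The fallback path below opens `ctx.source_file_path`, which mypy sees as `Optional`, hence the `# type: ignore[arg-type]` suppressions on the `open()` calls. A minimal sketch of the assert-narrowing style that keeps such a call type-clean (hypothetical helper, not the SDK's upload context):

    from typing import Optional

    def read_source(source_file_path: Optional[str]) -> bytes:
        # Suppression style:  open(source_file_path, "rb")  # type: ignore[arg-type]
        # Narrowing style mypy accepts without a suppression:
        assert source_file_path is not None, "file-based upload requires a source path"
        with open(source_file_path, "rb") as f:
            return f.read()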
- with open(ctx.source_file_path, "rb") as f: - return self._single_thread_single_shot_upload(ctx, f) + with open(ctx.source_file_path, "rb") as f: # type: ignore[arg-type] + return self._single_thread_single_shot_upload(ctx, f) # type: ignore[arg-type] except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") @@ -1381,8 +1381,8 @@ def _parallel_upload_from_file( elif initiate_upload_response.get("resumable_upload"): _LOG.warning("GCP does not support parallel resumable uploads, falling back to single-threaded upload") - with open(ctx.source_file_path, "rb") as f: - return self._upload_single_thread_with_known_size(ctx, f) + with open(ctx.source_file_path, "rb") as f: # type: ignore[arg-type] + return self._upload_single_thread_with_known_size(ctx, f) # type: ignore[arg-type] else: raise ValueError(f"Unexpected server response: {initiate_upload_response}") @@ -1400,19 +1400,19 @@ def _parallel_multipart_upload_from_file( session_token: str, ) -> None: # Calculate the number of parts. - file_size = os.path.getsize(ctx.source_file_path) + file_size = os.path.getsize(ctx.source_file_path) # type: ignore[arg-type] part_size = ctx.part_size num_parts = (file_size + part_size - 1) // part_size _LOG.debug(f"Uploading file of size {file_size} bytes in {num_parts} parts using {ctx.parallelism} threads") # Create queues and worker threads. - task_queue = Queue() - etags_result_queue = Queue() - exception_queue = Queue() + task_queue = Queue() # type: ignore[var-annotated] + etags_result_queue = Queue() # type: ignore[var-annotated] + exception_queue = Queue() # type: ignore[var-annotated] aborted = Event() workers = [ Thread(target=self._upload_file_consumer, args=(task_queue, etags_result_queue, exception_queue, aborted)) - for _ in range(ctx.parallelism) + for _ in range(ctx.parallelism) # type: ignore[arg-type] ] _LOG.debug(f"Starting {len(workers)} worker threads for parallel upload") @@ -1452,9 +1452,9 @@ def _parallel_multipart_upload_from_stream( cloud_provider_session: requests.Session, ) -> None: - task_queue = Queue(maxsize=ctx.parallelism) # Limit queue size to control memory usage - etags_result_queue = Queue() - exception_queue = Queue() + task_queue = Queue(maxsize=ctx.parallelism) # type: ignore[arg-type, var-annotated] # Limit queue size to control memory usage + etags_result_queue = Queue() # type: ignore[var-annotated] + exception_queue = Queue() # type: ignore[var-annotated] all_produced = Event() aborted = Event() @@ -1502,7 +1502,7 @@ def producer() -> None: target=self._upload_stream_consumer, args=(task_queue, etags_result_queue, exception_queue, all_produced, aborted), ) - for _ in range(ctx.parallelism) + for _ in range(ctx.parallelism) # type: ignore[arg-type] ] _LOG.debug(f"Starting {len(consumers)} worker threads for parallel upload") # Start producer and consumer threads @@ -1565,7 +1565,7 @@ def _upload_file_consumer( break try: - with open(part.ctx.source_file_path, "rb") as f: + with open(part.ctx.source_file_path, "rb") as f: # type: ignore[arg-type] f.seek(part.part_offset, os.SEEK_SET) part_content = BytesIO(f.read(part.part_size)) etag = self._do_upload_one_part( @@ -1669,7 +1669,7 @@ def _do_upload_one_part( required_headers = upload_part_url.get("headers", []) assert part_index == upload_part_url["part_number"] - headers: dict = {"Content-Type": "application/octet-stream"} + headers: dict = {"Content-Type": "application/octet-stream"} # type: ignore[no-redef] for h in required_headers: headers[h["name"]] = h["value"] @@ -1700,9 +1700,9 @@ def 
perform_upload() -> requests.Response: else: raise ValueError(f"Unsuccessful chunk upload: upload URL expired after {retry_count} retries") elif upload_response.status_code == 403: - raise FallbackToUploadUsingFilesApi(None, f"Direct upload forbidden: {upload_response.content}") + raise FallbackToUploadUsingFilesApi(None, f"Direct upload forbidden: {upload_response.content}") # type: ignore[str-bytes-safe] else: - message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" # type: ignore[str-bytes-safe] _LOG.warning(message) mapped_error = _error_mapper(upload_response, {}) raise mapped_error or ValueError(message) @@ -1793,7 +1793,7 @@ def _perform_multipart_upload( required_headers = upload_part_url.get("headers", []) assert current_part_number == upload_part_url["part_number"] - headers: dict = {"Content-Type": "application/octet-stream"} + headers: dict = {"Content-Type": "application/octet-stream"} # type: ignore[no-redef] for h in required_headers: headers[h["name"]] = h["value"] @@ -1846,10 +1846,10 @@ def perform(): # This might happen due to Azure firewall enabled for the customer bucket. # Let's fallback to using Files API which might be allowlisted to upload, passing # currently buffered (but not yet uploaded) part of the stream. - raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}") + raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}") # type: ignore[str-bytes-safe] else: - message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" # type: ignore[str-bytes-safe] _LOG.warning(message) mapped_error = _error_mapper(upload_response, {}) raise mapped_error or ValueError(message) @@ -1860,10 +1860,10 @@ def perform(): query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} headers = {"Content-Type": "application/json"} - body: dict = {} + body: dict = {} # type: ignore[no-redef] parts = [] - for etag in sorted(etags.items()): + for etag in sorted(etags.items()): # type: ignore[assignment] part = {"part_number": etag[0], "etag": etag[1]} parts.append(part) @@ -1913,7 +1913,7 @@ def _is_url_expired_response(response: requests.Response) -> bool: if code.text == "AuthenticationFailed": # Azure details = xml_root.find("AuthenticationErrorDetail") - if details is not None and "Signature not valid in the specified time frame" in details.text: + if details is not None and "Signature not valid in the specified time frame" in details.text: # type: ignore[operator] return True if code.text == "AccessDenied": @@ -2037,9 +2037,9 @@ def _perform_resumable_upload( else: # More chunks expected, let's upload current chunk (excluding read-ahead block). 
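Several of the error messages above interpolate `upload_response.content` (bytes) into an f-string, which mypy flags as `[str-bytes-safe]` because the rendered text would contain a literal `b'...'`. A minimal sketch of the warning and of an explicit decode that avoids it (illustrative message only):

    content: bytes = b'{"error": "expired"}'

    # Flagged by mypy; the formatted message would read "... body: b'...'".
    message = f"Unsuccessful chunk upload, body: {content}"  # type: ignore[str-bytes-safe]

    # Decoding (or using !r deliberately) makes the rendered text unambiguous.
    message = f"Unsuccessful chunk upload, body: {content.decode('utf-8', errors='replace')}"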
actual_chunk_length = ctx.part_size - file_size = "*" + file_size = "*" # type: ignore[assignment] - headers: dict = {"Content-Type": "application/octet-stream"} + headers: dict = {"Content-Type": "application/octet-stream"} # type: ignore[no-redef] for h in required_headers: headers[h["name"]] = h["value"] @@ -2094,7 +2094,7 @@ def perform(): and retry_count < self._config.files_ext_multipart_upload_max_retries ): retry_count += 1 - upload_response = retrieve_upload_status() + upload_response = retrieve_upload_status() # type: ignore[assignment] if not upload_response: # rethrow original exception raise e from None @@ -2144,7 +2144,7 @@ def perform(): raise AlreadyExists("The file being created already exists.") else: - message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" # type: ignore[str-bytes-safe] _LOG.warning(message) mapped_error = _error_mapper(upload_response, {}) raise mapped_error or ValueError(message) @@ -2213,7 +2213,7 @@ def _abort_multipart_upload( abort_url = abort_upload_url_node["url"] required_headers = abort_upload_url_node.get("headers", []) - headers: dict = {"Content-Type": "application/octet-stream"} + headers: dict = {"Content-Type": "application/octet-stream"} # type: ignore[no-redef] for h in required_headers: headers[h["name"]] = h["value"] @@ -2417,27 +2417,27 @@ def perform() -> requests.Response: stream=True, ) - csp_response: _RawResponse = self._retry_cloud_idempotent_operation(perform) + csp_response: _RawResponse = self._retry_cloud_idempotent_operation(perform) # type: ignore[assignment] # Mapping the error if the response is not successful. - if csp_response.status_code in (200, 201, 206): + if csp_response.status_code in (200, 201, 206): # type: ignore[attr-defined] resp = DownloadResponse( - content_length=int(csp_response.headers.get("content-length")), - content_type=csp_response.headers.get("content-type"), - last_modified=csp_response.headers.get("last-modified"), + content_length=int(csp_response.headers.get("content-length")), # type: ignore[attr-defined] + content_type=csp_response.headers.get("content-type"), # type: ignore[attr-defined] + last_modified=csp_response.headers.get("last-modified"), # type: ignore[attr-defined] contents=_StreamingResponse(csp_response, self._config.files_ext_client_download_streaming_chunk_size), ) return resp - elif csp_response.status_code == 403: + elif csp_response.status_code == 403: # type: ignore[attr-defined] # We got 403 failure when downloading the file. This might happen due to Azure firewall enabled for the customer bucket. # Let's fallback to using Files API which might be allowlisted to download. - raise FallbackToDownloadUsingFilesApi(f"Direct download forbidden: {csp_response.content}") + raise FallbackToDownloadUsingFilesApi(f"Direct download forbidden: {csp_response.content}") # type: ignore[attr-defined] else: message = ( - f"Unsuccessful download. Response status: {csp_response.status_code}, body: {csp_response.content}" + f"Unsuccessful download. 
Response status: {csp_response.status_code}, body: {csp_response.content}" # type: ignore[attr-defined] ) _LOG.warning(message) - mapped_error = _error_mapper(csp_response, {}) + mapped_error = _error_mapper(csp_response, {}) # type: ignore[arg-type] raise mapped_error or ValueError(message) def _init_download_response_mode_csp_with_fallback( @@ -2469,7 +2469,7 @@ def _wrap_stream( return _ResilientResponse( self, file_path, - download_response.last_modified, + download_response.last_modified, # type: ignore[arg-type] offset=start_byte_offset, underlying_response=underlying_response, ) @@ -2513,7 +2513,7 @@ class _ResilientIterator(Iterator): def _extract_raw_response( download_response: DownloadResponse, ) -> _RawResponse: - streaming_response: _StreamingResponse = download_response.contents + streaming_response: _StreamingResponse = download_response.contents # type: ignore[assignment] return streaming_response._response def __init__( @@ -2560,7 +2560,7 @@ def _recover(self) -> bool: self._recovers_without_progressing_count += 1 try: - self._underlying_iterator.close() + self._underlying_iterator.close() # type: ignore[attr-defined] _LOG.debug(f"Trying to recover from offset {self._offset}") @@ -2596,5 +2596,5 @@ def __next__(self) -> bytes: raise def close(self) -> None: - self._underlying_iterator.close() + self._underlying_iterator.close() # type: ignore[attr-defined] self._closed = True diff --git a/databricks/sdk/mixins/files_utils.py b/databricks/sdk/mixins/files_utils.py index fd07bcdce..356f1e4bc 100644 --- a/databricks/sdk/mixins/files_utils.py +++ b/databricks/sdk/mixins/files_utils.py @@ -138,13 +138,13 @@ def _get_stream_size(self, stream: BinaryIO) -> int: def _get_head_size(self) -> int: if self._head_size is None: - self._head_size = self._get_stream_size(self._head_stream) - return self._head_size + self._head_size = self._get_stream_size(self._head_stream) # type: ignore[assignment] + return self._head_size # type: ignore[return-value] def _get_tail_size(self) -> int: if self._tail_size is None: - self._tail_size = self._get_stream_size(self._tail_stream) - return self._tail_size + self._tail_size = self._get_stream_size(self._tail_stream) # type: ignore[assignment] + return self._tail_size # type: ignore[return-value] def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int: if not self.seekable(): @@ -217,10 +217,10 @@ def truncate(self, __size: Optional[int] = None) -> int: def writable(self) -> bool: return False - def write(self, __s: bytes) -> int: + def write(self, __s: bytes) -> int: # type: ignore[override] raise NotImplementedError("Stream is not writable") - def writelines(self, __lines: Iterable[bytes]) -> None: + def writelines(self, __lines: Iterable[bytes]) -> None: # type: ignore[override] raise NotImplementedError("Stream is not writable") def __next__(self) -> bytes: @@ -276,8 +276,8 @@ def get_url(self) -> tuple[CreateDownloadUrlResponse, int]: """ with self.lock: if self._current_url is None: - self._current_url = self._get_new_url_func() - return self._current_url, self.current_version + self._current_url = self._get_new_url_func() # type: ignore[assignment] + return self._current_url, self.current_version # type: ignore[return-value] def invalidate_url(self, version: int) -> None: """ diff --git a/databricks/sdk/mixins/jobs.py b/databricks/sdk/mixins/jobs.py index 1e6cf25d5..e37f022a3 100644 --- a/databricks/sdk/mixins/jobs.py +++ b/databricks/sdk/mixins/jobs.py @@ -51,11 +51,11 @@ def list( # fully fetch all top level arrays for each job 
in the list for job in jobs_list: if job.has_more: - job_from_get_call = self.get(job.job_id) - job.settings.tasks = job_from_get_call.settings.tasks - job.settings.job_clusters = job_from_get_call.settings.job_clusters - job.settings.parameters = job_from_get_call.settings.parameters - job.settings.environments = job_from_get_call.settings.environments + job_from_get_call = self.get(job.job_id) # type: ignore[arg-type] + job.settings.tasks = job_from_get_call.settings.tasks # type: ignore[union-attr] + job.settings.job_clusters = job_from_get_call.settings.job_clusters # type: ignore[union-attr] + job.settings.parameters = job_from_get_call.settings.parameters # type: ignore[union-attr] + job.settings.environments = job_from_get_call.settings.environments # type: ignore[union-attr] # Remove has_more fields for each job in the list. # This field in Jobs API 2.2 is useful for pagination. It indicates if there are more than 100 tasks or job_clusters in the job. # This function hides pagination details from the user. So the field does not play useful role here. @@ -134,7 +134,7 @@ def list_runs( # fully fetch all top level arrays for each run in the list for run in runs_list: if run.has_more: - run_from_get_call = self.get_run(run.run_id) + run_from_get_call = self.get_run(run.run_id) # type: ignore[arg-type] run.tasks = run_from_get_call.tasks run.job_clusters = run_from_get_call.job_clusters run.job_parameters = run_from_get_call.job_parameters @@ -190,13 +190,13 @@ def get_run( page_token=run.next_page_token, ) if is_paginating_iterations: - run.iterations.extend(next_run.iterations) + run.iterations.extend(next_run.iterations) # type: ignore[arg-type, union-attr] else: - run.tasks.extend(next_run.tasks) + run.tasks.extend(next_run.tasks) # type: ignore[arg-type, union-attr] # Each new page of runs/get response includes the next page of the job_clusters, job_parameters, and repair history. - run.job_clusters.extend(next_run.job_clusters) - run.job_parameters.extend(next_run.job_parameters) - run.repair_history.extend(next_run.repair_history) + run.job_clusters.extend(next_run.job_clusters) # type: ignore[arg-type, union-attr] + run.job_parameters.extend(next_run.job_parameters) # type: ignore[arg-type, union-attr] + run.repair_history.extend(next_run.repair_history) # type: ignore[arg-type, union-attr] run.next_page_token = next_run.next_page_token return run @@ -221,10 +221,10 @@ def get(self, job_id: int, *, page_token: Optional[str] = None) -> Job: while job.next_page_token is not None: next_job = super().get(job_id, page_token=job.next_page_token) # Each new page of jobs/get response includes the next page of the tasks, job_clusters, job_parameters, and environments. 
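The `union-attr` suppressions in this jobs mixin cover attribute access through values mypy sees as `Optional` (for example `job.settings` or `run.tasks`). A minimal sketch of the narrowing style that avoids the suppression, using simplified stand-in dataclasses rather than the SDK's generated models:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Settings:
        tasks: Optional[List[str]] = None

    @dataclass
    class Job:
        settings: Optional[Settings] = None

    def merge_tasks(job: Job, more: List[str]) -> None:
        # Suppression style: job.settings.tasks.extend(more)  # type: ignore[union-attr]
        # Narrowed style that mypy accepts as written:
        if job.settings is not None and job.settings.tasks is not None:
            job.settings.tasks.extend(more)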
- job.settings.tasks.extend(next_job.settings.tasks) - job.settings.job_clusters.extend(next_job.settings.job_clusters) - job.settings.parameters.extend(next_job.settings.parameters) - job.settings.environments.extend(next_job.settings.environments) + job.settings.tasks.extend(next_job.settings.tasks) # type: ignore[arg-type, union-attr] + job.settings.job_clusters.extend(next_job.settings.job_clusters) # type: ignore[arg-type, union-attr] + job.settings.parameters.extend(next_job.settings.parameters) # type: ignore[arg-type, union-attr] + job.settings.environments.extend(next_job.settings.environments) # type: ignore[arg-type, union-attr] job.next_page_token = next_job.next_page_token return job diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py index 4ab08ee5a..64a8c64ad 100644 --- a/databricks/sdk/mixins/open_ai_client.py +++ b/databricks/sdk/mixins/open_ai_client.py @@ -1,7 +1,7 @@ import json as js from typing import Dict, Optional -from requests import Response +from requests import Response # type: ignore[import-untyped] from databricks.sdk.service.serving import (ExternalFunctionRequestHttpMethod, HttpRequestResponse, @@ -13,7 +13,7 @@ class ServingEndpointsExt(ServingEndpointsAPI): # Using the HTTP Client to pass in the databricks authorization # This method will be called on every invocation, so when using with model serving will always get the refreshed token def _get_authorized_http_client(self): - import httpx + import httpx # type: ignore[import-not-found] class BearerAuth(httpx.Auth): @@ -67,7 +67,7 @@ def get_open_ai_client(self, **kwargs): ... ) """ try: - from openai import OpenAI + from openai import OpenAI # type: ignore[import-not-found] except Exception: raise ImportError( "Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]`" @@ -96,7 +96,7 @@ def get_open_ai_client(self, **kwargs): def get_langchain_chat_open_ai_client(self, model): try: - from langchain_openai import ChatOpenAI + from langchain_openai import ChatOpenAI # type: ignore[import-not-found] except Exception: raise ImportError( "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip install databricks-sdk[openai]` and ensure you are using python>3.7" @@ -109,15 +109,15 @@ def get_langchain_chat_open_ai_client(self, model): http_client=self._get_authorized_http_client(), ) - def http_request( + def http_request( # type: ignore[override] self, conn: str, method: ExternalFunctionRequestHttpMethod, path: str, *, - headers: Optional[Dict[str, str]] = None, - json: Optional[Dict[str, str]] = None, - params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, # type: ignore[override] + json: Optional[Dict[str, str]] = None, # type: ignore[override] + params: Optional[Dict[str, str]] = None, # type: ignore[override] ) -> Response: """Make external services call using the credentials stored in UC Connection. **NOTE:** Experimental: This API may change or be removed in a future release without warning. 
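The `# type: ignore[override]` on `http_request` acknowledges that the mixin's signature is narrower than the generated base class's, which mypy rejects as an unsafe override. A minimal sketch of the same situation with hypothetical classes (the real `ServingEndpointsAPI` signature may differ):

    from typing import Dict, Optional

    class Base:
        def http_request(self, *, headers: Optional[Dict[str, object]] = None) -> None: ...

    class Mixin(Base):
        # Narrowing an argument type in a subclass is unsafe for callers holding a
        # Base reference, so mypy reports [override]; the patch suppresses it rather
        # than widening the annotation back to the base type.
        def http_request(self, *, headers: Optional[Dict[str, str]] = None) -> None:  # type: ignore[override]
            ...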
@@ -164,7 +164,7 @@ def http_request( # Read the content from the HttpRequestResponse object if hasattr(server_response, "contents") and hasattr(server_response.contents, "read"): - raw_content = server_response.contents.read() # Read the bytes + raw_content = server_response.contents.read() # type: ignore[union-attr] # Read the bytes else: raise ValueError("Invalid response from the server.") diff --git a/databricks/sdk/mixins/sharing.py b/databricks/sdk/mixins/sharing.py index 65e03d665..457683d62 100644 --- a/databricks/sdk/mixins/sharing.py +++ b/databricks/sdk/mixins/sharing.py @@ -27,7 +27,7 @@ def list(self, *, max_results: Optional[int] = None, page_token: Optional[str] = if max_results is not None: query["max_results"] = max_results if page_token is not None: - query["page_token"] = page_token + query["page_token"] = page_token # type: ignore[assignment] headers = { "Accept": "application/json", } diff --git a/databricks/sdk/mixins/workspace.py b/databricks/sdk/mixins/workspace.py index f62ad5bff..6a5a74566 100644 --- a/databricks/sdk/mixins/workspace.py +++ b/databricks/sdk/mixins/workspace.py @@ -33,7 +33,7 @@ def list( path, queue = queue[0], queue[1:] for object_info in parent_list(path, notebooks_modified_after=notebooks_modified_after): if recursive and object_info.object_type == ObjectType.DIRECTORY: - queue.append(object_info.path) + queue.append(object_info.path) # type: ignore[arg-type] continue yield object_info diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index f18f0cd51..095d3f7a0 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -16,8 +16,8 @@ from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any, Dict, List, Optional -import requests -import requests.auth +import requests # type: ignore[import-untyped] +import requests.auth # type: ignore[import-untyped] from ._base_client import _BaseClient, _fix_host_if_needed @@ -64,8 +64,8 @@ class OidcEndpoints: @staticmethod def from_dict(d: dict) -> "OidcEndpoints": return OidcEndpoints( - authorization_endpoint=d.get("authorization_endpoint"), - token_endpoint=d.get("token_endpoint"), + authorization_endpoint=d.get("authorization_endpoint"), # type: ignore[arg-type] + token_endpoint=d.get("token_endpoint"), # type: ignore[arg-type] ) def as_dict(self) -> dict: @@ -180,7 +180,7 @@ def retrieve_token( if use_header: auth = requests.auth.HTTPBasicAuth(client_id, client_secret) else: - auth = IgnoreNetrcAuth() + auth = IgnoreNetrcAuth() # type: ignore[assignment] resp = requests.post(token_url, params, auth=auth, headers=headers) if not resp.ok: if resp.headers["Content-Type"].startswith("application/json"): @@ -376,10 +376,10 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas :param account_id: The account ID. :return: The account's OIDC endpoints. """ - host = _fix_host_if_needed(host) + host = _fix_host_if_needed(host) # type: ignore[assignment] oidc = f"{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server" resp = client.do("GET", oidc) - return OidcEndpoints.from_dict(resp) + return OidcEndpoints.from_dict(resp) # type: ignore[arg-type] def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: @@ -388,10 +388,10 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O :param host: The Databricks workspace host. :return: The workspace's OIDC endpoints. 
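The `arg-type` suppressions in `OidcEndpoints.from_dict` come from `dict.get()`, which returns an `Optional` value that is then passed to a parameter typed `str`. A minimal sketch of the difference, with a hypothetical constructor-like function:

    from typing import Dict

    def make_endpoint(authorization_endpoint: str) -> str:
        return authorization_endpoint

    d: Dict[str, str] = {"authorization_endpoint": "https://example.test/authorize"}

    # d.get("authorization_endpoint") is Optional[str] -> [arg-type] if passed as-is.
    # Indexing (or a None check / explicit default) keeps the value a plain str.
    endpoint = make_endpoint(d["authorization_endpoint"])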
""" - host = _fix_host_if_needed(host) + host = _fix_host_if_needed(host) # type: ignore[assignment] oidc = f"{host}/oidc/.well-known/oauth-authorization-server" resp = client.do("GET", oidc) - return OidcEndpoints.from_dict(resp) + return OidcEndpoints.from_dict(resp) # type: ignore[arg-type] def get_azure_entra_id_workspace_endpoints( @@ -404,7 +404,7 @@ def get_azure_entra_id_workspace_endpoints( :return: The OIDC endpoints for the workspace's Azure Entra ID tenant. """ # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint - host = _fix_host_if_needed(host) + host = _fix_host_if_needed(host) # type: ignore[assignment] res = requests.get(f"{host}/oidc/oauth2/v2.0/authorize", allow_redirects=False) real_auth_url = res.headers.get("location") if not real_auth_url: @@ -421,8 +421,8 @@ def __init__( token: Token, token_endpoint: str, client_id: str, - client_secret: str = None, - redirect_url: str = None, + client_secret: str = None, # type: ignore[assignment] + redirect_url: str = None, # type: ignore[assignment] disable_async: bool = True, ): self._token_endpoint = token_endpoint @@ -442,8 +442,8 @@ def from_dict( raw: dict, token_endpoint: str, client_id: str, - client_secret: str = None, - redirect_url: str = None, + client_secret: str = None, # type: ignore[assignment] + redirect_url: str = None, # type: ignore[assignment] ) -> "SessionCredentials": return SessionCredentials( token=Token.from_dict(raw["token"]), @@ -498,7 +498,7 @@ def __init__( redirect_url: str, token_endpoint: str, client_id: str, - client_secret: str = None, + client_secret: str = None, # type: ignore[assignment] ) -> None: self._verifier = verifier self._state = state @@ -523,7 +523,7 @@ def authorization_url(self) -> str: return self._authorization_url @staticmethod - def from_dict(raw: dict, client_secret: str = None) -> "Consent": + def from_dict(raw: dict, client_secret: str = None) -> "Consent": # type: ignore[assignment] return Consent( raw["state"], raw["verifier"], @@ -538,12 +538,12 @@ def launch_external_browser(self) -> SessionCredentials: redirect_url = urllib.parse.urlparse(self._redirect_url) if redirect_url.hostname not in ("localhost", "127.0.0.1"): raise ValueError(f"cannot listen on {redirect_url.hostname}") - feedback = [] + feedback = [] # type: ignore[var-annotated] logger.info(f"Opening {self._authorization_url} in a browser") webbrowser.open_new(self._authorization_url) port = redirect_url.port handler_factory = functools.partial(_OAuthCallback, feedback) - with HTTPServer(("localhost", port), handler_factory) as httpd: + with HTTPServer(("localhost", port), handler_factory) as httpd: # type: ignore[arg-type] logger.info(f"Waiting for redirect to http://localhost:{port}") httpd.handle_request() if not feedback: @@ -567,7 +567,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: "code_verifier": self._verifier, "code": code, } - headers = {} + headers = {} # type: ignore[var-annotated] while True: try: token = retrieve_token( @@ -620,8 +620,8 @@ def __init__( oidc_endpoints: OidcEndpoints, redirect_url: str, client_id: str, - scopes: List[str] = None, - client_secret: str = None, + scopes: List[str] = None, # type: ignore[assignment] + client_secret: str = None, # type: ignore[assignment] ): if not scopes: # all-apis ensures that the returned OAuth token can be used with all APIs, aside @@ -642,14 +642,14 @@ def from_host( client_id: str, redirect_url: str, *, - scopes: List[str] = None, - client_secret: str = None, + scopes: List[str] = None, # 
type: ignore[assignment] + client_secret: str = None, # type: ignore[assignment] ) -> "OAuthClient": from .core import Config from .credentials_provider import credentials_strategy @credentials_strategy("noop", []) - def noop_credentials(_: any): + def noop_credentials(_: any): # type: ignore[valid-type] return lambda: {} config = Config(host=host, credentials_strategy=noop_credentials) @@ -705,8 +705,8 @@ class ClientCredentials(Refreshable): client_id: str client_secret: str token_url: str - endpoint_params: dict = None - scopes: List[str] = None + endpoint_params: dict = None # type: ignore[assignment] + scopes: List[str] = None # type: ignore[assignment] use_params: bool = False use_header: bool = False disable_async: bool = True @@ -776,8 +776,8 @@ def load(self) -> Optional[SessionCredentials]: raw, token_endpoint=self._oidc_endpoints.token_endpoint, client_id=self._client_id, - client_secret=self._client_secret, - redirect_url=self._redirect_url, + client_secret=self._client_secret, # type: ignore[arg-type] + redirect_url=self._redirect_url, # type: ignore[arg-type] ) except Exception: return None diff --git a/databricks/sdk/oidc.py b/databricks/sdk/oidc.py index c90313a4c..f29786c8f 100644 --- a/databricks/sdk/oidc.py +++ b/databricks/sdk/oidc.py @@ -194,7 +194,7 @@ def token(self) -> oauth.Token: # It exists to make it easier to test. def _exchange_id_token(self, id_token: IdToken) -> oauth.Token: client = oauth.ClientCredentials( - client_id=self._client_id, + client_id=self._client_id, # type: ignore[arg-type] client_secret="", # there is no (rotatable) secrets in the OIDC flow token_url=self._token_endpoint, endpoint_params={ diff --git a/databricks/sdk/oidc_token_supplier.py b/databricks/sdk/oidc_token_supplier.py index bd050dd5f..413c13f68 100644 --- a/databricks/sdk/oidc_token_supplier.py +++ b/databricks/sdk/oidc_token_supplier.py @@ -2,7 +2,7 @@ import os from typing import Optional -import requests +import requests # type: ignore[import-untyped] logger = logging.getLogger("databricks.sdk") diff --git a/databricks/sdk/retries.py b/databricks/sdk/retries.py index a6cf5d8dc..a825007ec 100644 --- a/databricks/sdk/retries.py +++ b/databricks/sdk/retries.py @@ -76,7 +76,7 @@ class RetryError(Exception): def __init__(self, err: Exception, halt: bool = False): self.err = err - self.halt = halt + self.halt = halt # type: ignore[assignment, method-assign] super().__init__(str(err)) @staticmethod @@ -149,9 +149,9 @@ def check_operation(): result, err = fn() if err is None: - return result + return result # type: ignore[return-value] - if err.halt: + if err.halt: # type: ignore[truthy-function] raise err.err # Continue polling. 
diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index adf26c707..f64c0cd3e 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -26,7 +26,7 @@ try: # We don't want to expose additional entity to user namespace, so # a workaround here for exposing required information in notebook environment - from dbruntime.sdk_credential_provider import init_runtime_native_auth + from dbruntime.sdk_credential_provider import init_runtime_native_auth # type: ignore[import-not-found] logger.debug("runtime SDK credential provider available") dbruntime_objects.append("init_runtime_native_auth") @@ -38,7 +38,7 @@ def init_runtime_repl_auth(): try: - from dbruntime.databricks_repl_context import get_context + from dbruntime.databricks_repl_context import get_context # type: ignore[import-not-found] ctx = get_context() if ctx is None: @@ -60,7 +60,7 @@ def inner() -> Dict[str, str]: def init_runtime_legacy_auth(): try: - import IPython + import IPython # type: ignore[import-not-found] ip_shell = IPython.get_ipython() if ip_shell is None: @@ -88,7 +88,7 @@ def inner() -> Dict[str, str]: try: # Internal implementation # Separated from above for backward compatibility - from dbruntime import UserNamespaceInitializer + from dbruntime import UserNamespaceInitializer # type: ignore[import-not-found] userNamespaceGlobals = UserNamespaceInitializer.getOrCreate().get_namespace_globals() _globals = globals() @@ -108,7 +108,7 @@ def inner() -> Dict[str, str]: # mannaer. We separate them to try to get as many of them working as possible try: # We expect this to fail and only do this for providing types - from pyspark.sql.context import SQLContext + from pyspark.sql.context import SQLContext # type: ignore[import-not-found] sqlContext: SQLContext = None # type: ignore table = sqlContext.table @@ -187,7 +187,7 @@ def displayHTML(html) -> None: # type: ignore dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils] dbutils = RemoteDbUtils() - dbutils = cast(dbutils_type, dbutils) + dbutils = cast(dbutils_type, dbutils) # type: ignore[assignment] # We do this to prevent importing widgets implementation prematurely # The widget import should prompt users to use the implementation diff --git a/databricks/sdk/runtime/dbutils_stub.py b/databricks/sdk/runtime/dbutils_stub.py index 363436e1f..04038b646 100644 --- a/databricks/sdk/runtime/dbutils_stub.py +++ b/databricks/sdk/runtime/dbutils_stub.py @@ -28,28 +28,28 @@ class credentials: """ @staticmethod - def assumeRole(role: str) -> bool: + def assumeRole(role: str) -> bool: # type: ignore[empty-body] """ Sets the role ARN to assume when looking for credentials to authenticate with S3 """ ... @staticmethod - def showCurrentRole() -> typing.List[str]: + def showCurrentRole() -> typing.List[str]: # type: ignore[empty-body] """ Shows the currently set role """ ... @staticmethod - def showRoles() -> typing.List[str]: + def showRoles() -> typing.List[str]: # type: ignore[empty-body] """ Shows the set of possibly assumed roles """ ... @staticmethod - def getCurrentCredentials() -> typing.Mapping[str, str]: ... + def getCurrentCredentials() -> typing.Mapping[str, str]: ... # type: ignore[empty-body] class data: """ @@ -57,7 +57,7 @@ class data: """ @staticmethod - def summarize(df: any, precise: bool = False) -> None: + def summarize(df: any, precise: bool = False) -> None: # type: ignore[valid-type] """Summarize a Spark/pandas/Koalas DataFrame and visualize the statistics to get quick insights. 
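The `# type: ignore[valid-type]` on `summarize` (and on the other stub signatures that annotate with `any`) exists because `any` is the builtin function, not a type, so mypy refuses it as an annotation. A minimal sketch of the suppressed spelling next to the `typing.Any` spelling that needs no suppression (hypothetical function names):

    import typing

    def summarize_suppressed(df: any) -> None:  # type: ignore[valid-type]
        print(df)

    def summarize_clean(df: typing.Any) -> None:
        print(df)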
Example: dbutils.data.summarize(df) @@ -79,49 +79,49 @@ class fs: """ @staticmethod - def cp(source: str, dest: str, recurse: bool = False) -> bool: + def cp(source: str, dest: str, recurse: bool = False) -> bool: # type: ignore[empty-body] """ Copies a file or directory, possibly across FileSystems """ ... @staticmethod - def head(file: str, max_bytes: int = 65536) -> str: + def head(file: str, max_bytes: int = 65536) -> str: # type: ignore[empty-body] """ Returns up to the first 'maxBytes' bytes of the given file as a String encoded in UTF-8 """ ... @staticmethod - def ls(path: str) -> typing.List[FileInfo]: + def ls(path: str) -> typing.List[FileInfo]: # type: ignore[empty-body] """ Lists the contents of a directory """ ... @staticmethod - def mkdirs(dir: str) -> bool: + def mkdirs(dir: str) -> bool: # type: ignore[empty-body] """ Creates the given directory if it does not exist, also creating any necessary parent directories """ ... @staticmethod - def mv(source: str, dest: str, recurse: bool = False) -> bool: + def mv(source: str, dest: str, recurse: bool = False) -> bool: # type: ignore[empty-body] """ Moves a file or directory, possibly across FileSystems """ ... @staticmethod - def put(file: str, contents: str, overwrite: bool = False) -> bool: + def put(file: str, contents: str, overwrite: bool = False) -> bool: # type: ignore[empty-body] """ Writes the given String out to a file, encoded in UTF-8 """ ... @staticmethod - def rm(dir: str, recurse: bool = False) -> bool: + def rm(dir: str, recurse: bool = False) -> bool: # type: ignore[empty-body] """ Removes a file or directory """ @@ -140,7 +140,7 @@ def uncacheFiles(*files): ... def uncacheTable(name: str): ... @staticmethod - def mount( + def mount( # type: ignore[empty-body] source: str, mount_point: str, encryption_type: str = "", @@ -153,7 +153,7 @@ def mount( ... @staticmethod - def updateMount( + def updateMount( # type: ignore[empty-body] source: str, mount_point: str, encryption_type: str = "", @@ -166,21 +166,21 @@ def updateMount( ... @staticmethod - def mounts() -> typing.List[MountInfo]: + def mounts() -> typing.List[MountInfo]: # type: ignore[empty-body] """ Displays information about what is mounted within DBFS """ ... @staticmethod - def refreshMounts() -> bool: + def refreshMounts() -> bool: # type: ignore[empty-body] """ Forces all machines in this cluster to refresh their mount cache, ensuring they receive the most recent information """ ... @staticmethod - def unmount(mount_point: str) -> bool: + def unmount(mount_point: str) -> bool: # type: ignore[empty-body] """ Deletes a DBFS mount point """ @@ -200,8 +200,8 @@ class taskValues: def get( taskKey: str, key: str, - default: any = None, - debugValue: any = None, + default: any = None, # type: ignore[valid-type] + debugValue: any = None, # type: ignore[valid-type] ) -> None: """ Returns the latest task value that belongs to the current job run @@ -209,7 +209,7 @@ def get( ... @staticmethod - def set(key: str, value: any) -> None: + def set(key: str, value: any) -> None: # type: ignore[valid-type] """ Sets a task value on the current task run """ @@ -240,7 +240,7 @@ def exit(value: str) -> None: ... 
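The `empty-body` suppressions throughout this stub module cover functions whose body is just `...` while their return annotation is not `None`; mypy reports a missing return statement for those. A minimal sketch of the suppressed form and one way to satisfy the checker without it (hypothetical stub class, assuming the real implementations are injected by the Databricks runtime):

    class widgets_sketch:
        @staticmethod
        def get(name: str) -> str: ...  # type: ignore[empty-body]

        # Raising makes the placeholder explicit and needs no suppression.
        @staticmethod
        def getArgument(name: str, defaultValue: str = "") -> str:
            raise NotImplementedError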
@staticmethod - def run( + def run( # type: ignore[empty-body] path: str, timeout_seconds: int, arguments: typing.Mapping[str, str], @@ -256,25 +256,25 @@ class secrets: """ @staticmethod - def get(scope: str, key: str) -> str: + def get(scope: str, key: str) -> str: # type: ignore[empty-body] """ Gets the string representation of a secret value with scope and key """ ... @staticmethod - def getBytes(self, scope: str, key: str) -> bytes: + def getBytes(self, scope: str, key: str) -> bytes: # type: ignore[empty-body] """Gets the bytes representation of a secret value for the specified scope and key.""" @staticmethod - def list(scope: str) -> typing.List[SecretMetadata]: + def list(scope: str) -> typing.List[SecretMetadata]: # type: ignore[empty-body] """ Lists secret metadata for secrets within a scope """ ... @staticmethod - def listScopes() -> typing.List[SecretScope]: + def listScopes() -> typing.List[SecretScope]: # type: ignore[empty-body] """ Lists secret scopes """ @@ -286,7 +286,7 @@ class widgets: """ @staticmethod - def get(name: str) -> str: + def get(name: str) -> str: # type: ignore[empty-body] """Returns the current value of a widget with give name. :param name: Name of the argument to be accessed :return: Current value of the widget or default value @@ -303,7 +303,7 @@ def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> typing. ... @staticmethod - def text(name: str, defaultValue: str, label: str = None): + def text(name: str, defaultValue: str, label: str = None): # type: ignore[assignment] """Creates a text input widget with given name, default value and optional label for display :param name: Name of argument associated with the new input widget @@ -317,7 +317,7 @@ def dropdown( name: str, defaultValue: str, choices: typing.List[str], - label: str = None, + label: str = None, # type: ignore[assignment] ): """Creates a dropdown input widget with given specification. 
:param name: Name of argument associated with the new input widget diff --git a/pyproject.toml b/pyproject.toml index 5535d2186..2d88f501e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "pycodestyle", "autoflake", "isort", + "mypy", "wheel", "ipython", "ipywidgets", @@ -85,4 +86,22 @@ include = ["."] exclude = ["**/node_modules", "**/__pycache__"] reportMissingImports = true reportMissingTypeStubs = false -pythonVersion = "3.7" \ No newline at end of file +pythonVersion = "3.7" + +[tool.mypy] +strict = true +python_version = "3.9" +files = ["databricks", "tests"] + +# Don't type-check generated files +[[tool.mypy.overrides]] +module = [ + "databricks.sdk.__init__", + "databricks.sdk.errors.overrides", + "databricks.sdk.errors.platform", + "databricks.sdk.service.*", + "tests.databricks.sdk.service.*", + "tests.generated.*", +] +ignore_errors = true +follow_imports = "skip" \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index cb2efc0a2..14e23d708 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,8 @@ import os import platform -import pytest as pytest -from pyfakefs.fake_filesystem_unittest import Patcher +import pytest as pytest # type: ignore[import-not-found] +from pyfakefs.fake_filesystem_unittest import Patcher # type: ignore[import-not-found] from databricks.sdk.core import Config from databricks.sdk.credentials_provider import credentials_strategy @@ -13,7 +13,7 @@ @credentials_strategy("noop", []) -def noop_credentials(_: any): +def noop_credentials(_: any): # type: ignore[valid-type] return lambda: {} diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 55114bd84..cad5b7c72 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -5,7 +5,7 @@ import string import sys -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk import AccountClient, FilesAPI, FilesExt, WorkspaceClient from databricks.sdk.service.catalog import VolumeType diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index 14aea59bf..9a296dfa5 100644 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -10,7 +10,7 @@ from functools import partial from pathlib import Path -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk import AccountClient, WorkspaceClient from databricks.sdk.service import iam, oauth2 diff --git a/tests/integration/test_clusters.py b/tests/integration/test_clusters.py index dd388d2ed..042b64576 100644 --- a/tests/integration/test_clusters.py +++ b/tests/integration/test_clusters.py @@ -1,7 +1,7 @@ import logging from datetime import timedelta -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk.core import DatabricksError from databricks.sdk.service.compute import EventType diff --git a/tests/integration/test_commands.py b/tests/integration/test_commands.py index e60302818..309bced48 100644 --- a/tests/integration/test_commands.py +++ b/tests/integration/test_commands.py @@ -1,4 +1,4 @@ -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk.core import DatabricksError diff --git a/tests/integration/test_dbconnect.py b/tests/integration/test_dbconnect.py index ecc82855b..25dace7a7 100644 --- a/tests/integration/test_dbconnect.py +++ b/tests/integration/test_dbconnect.py @@ -1,4 +1,4 @@ -import pytest +import pytest # type: ignore[import-not-found] DBCONNECT_DBR_CLIENT = { "13.3": "13.3.3", @@ -51,7 +51,7 @@ 
def setup_dbconnect_test(request, env_or_skip, restorable_env): @pytest.mark.xdist_group(name="databricks-connect") def test_dbconnect_initialisation(w, setup_dbconnect_test): reload_modules("databricks.connect") - from databricks.connect import DatabricksSession + from databricks.connect import DatabricksSession # type: ignore[import-not-found] spark = DatabricksSession.builder.getOrCreate() assert spark.sql("SELECT 1").collect()[0][0] == 1 diff --git a/tests/integration/test_dbutils.py b/tests/integration/test_dbutils.py index feafac00a..c52c7deb0 100644 --- a/tests/integration/test_dbutils.py +++ b/tests/integration/test_dbutils.py @@ -2,7 +2,7 @@ import logging import os -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk.core import DatabricksError from databricks.sdk.errors import NotFound diff --git a/tests/integration/test_deployment.py b/tests/integration/test_deployment.py index 2071645d2..945dd0d04 100644 --- a/tests/integration/test_deployment.py +++ b/tests/integration/test_deployment.py @@ -1,6 +1,6 @@ import logging -import pytest +import pytest # type: ignore[import-not-found] def test_workspaces(a): diff --git a/tests/integration/test_external_browser.py b/tests/integration/test_external_browser.py index 883069217..54b450d04 100644 --- a/tests/integration/test_external_browser.py +++ b/tests/integration/test_external_browser.py @@ -1,4 +1,4 @@ -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk import WorkspaceClient diff --git a/tests/integration/test_files.py b/tests/integration/test_files.py index 348f88b05..32c135528 100644 --- a/tests/integration/test_files.py +++ b/tests/integration/test_files.py @@ -5,7 +5,7 @@ import time from typing import Callable, List, Tuple, Union -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk.core import DatabricksError from databricks.sdk.service.catalog import VolumeType diff --git a/tests/integration/test_iam.py b/tests/integration/test_iam.py index cc40c039c..5df09fb69 100644 --- a/tests/integration/test_iam.py +++ b/tests/integration/test_iam.py @@ -1,4 +1,4 @@ -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk import errors from databricks.sdk.core import DatabricksError diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py index cfc8de0b7..be733a84a 100644 --- a/tests/integration/test_jobs.py +++ b/tests/integration/test_jobs.py @@ -36,7 +36,7 @@ def test_submitting_jobs(w, random, env_or_skip): logging.info(f"starting to poll: {waiter.run_id}") def print_status(run: jobs.Run): - statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in run.tasks] + statuses = [f"{t.task_key}: {t.state.life_cycle_state}" for t in run.tasks] # type: ignore[union-attr] logging.info(f'workflow intermediate status: {", ".join(statuses)}') run = waiter.result(timeout=datetime.timedelta(minutes=15), callback=print_status) diff --git a/tests/test_auth_manual_tests.py b/tests/test_auth_manual_tests.py index f66e92ea8..e4a027d48 100644 --- a/tests/test_auth_manual_tests.py +++ b/tests/test_auth_manual_tests.py @@ -1,4 +1,4 @@ -import pytest +import pytest # type: ignore[import-not-found] from databricks.sdk.core import Config diff --git a/tests/test_base_client.py b/tests/test_base_client.py index 8b3501d49..f99e3b8fa 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -4,8 +4,8 @@ from typing import Callable, Iterator, List, Optional, Tuple, Type from unittest.mock import Mock 
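The `# type: ignore[method-assign]` in the base-client tests that follow marks places where a method is overwritten on an instance (faking `seekable` on a stream), which mypy rejects by default. A minimal sketch of the suppression and of a `mock.patch.object` variant that sidesteps it (illustrative only; the real test fakes a non-seekable upload stream):

    import io
    from unittest import mock

    stream = io.BytesIO(b"payload")

    # Suppression style used in the tests:
    #   stream.seekable = lambda: False  # type: ignore[method-assign]

    # Patch-based style that needs no suppression and restores the method afterwards.
    with mock.patch.object(stream, "seekable", return_value=False):
        assert stream.seekable() is False
    assert stream.seekable() is True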
-import pytest
-from requests import PreparedRequest, Response, Timeout
+import pytest  # type: ignore[import-not-found]
+from requests import PreparedRequest, Response, Timeout  # type: ignore[import-untyped]

 from databricks.sdk import errors, useragent
 from databricks.sdk._base_client import (_BaseClient, _RawResponse,
@@ -437,7 +437,7 @@ def get_data(self):

     @classmethod
     def create_non_seekable_stream(cls, data: bytes):
         result = io.BytesIO(data)
-        result.seekable = lambda: False  # makes the stream appear non-seekable
+        result.seekable = lambda: False  # type: ignore[method-assign] # makes the stream appear non-seekable
         return result

@@ -550,7 +550,7 @@ def test_rewind_seekable_stream(test_case: RetryTestCase, failure: Tuple[Callabl
     session = MockSession(failure_count, failure[0])

     client = _BaseClient()
-    client._session = session
+    client._session = session  # type: ignore[assignment]

     def do():
         client.do("POST", "test.com/foo", data=data)
diff --git a/tests/test_client.py b/tests/test_client.py
index 7eaf308f1..833428e9c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1,6 +1,6 @@
 from unittest.mock import create_autospec

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk import WorkspaceClient

diff --git a/tests/test_compute_mixins.py b/tests/test_compute_mixins.py
index ec895b022..426f51423 100644
--- a/tests/test_compute_mixins.py
+++ b/tests/test_compute_mixins.py
@@ -1,4 +1,4 @@
-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.mixins.compute import SemVer

diff --git a/tests/test_config.py b/tests/test_config.py
index 59fbf8712..6249f1766 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -5,7 +5,7 @@
 import string
 from datetime import datetime

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk import oauth, useragent
 from databricks.sdk.config import Config, with_product, with_user_agent_extra
diff --git a/tests/test_core.py b/tests/test_core.py
index cc8ed921d..fdef090aa 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -6,7 +6,7 @@
 from datetime import datetime
 from http.server import BaseHTTPRequestHandler

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk import WorkspaceClient, errors, useragent
 from databricks.sdk.core import ApiClient, Config, DatabricksError
diff --git a/tests/test_dbfs_mixins.py b/tests/test_dbfs_mixins.py
index 4332c8475..3e02da283 100644
--- a/tests/test_dbfs_mixins.py
+++ b/tests/test_dbfs_mixins.py
@@ -1,4 +1,4 @@
-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.errors import NotFound
 from databricks.sdk.mixins.files import (DbfsExt, _DbfsPath, _LocalPath,
@@ -14,7 +14,7 @@ def test_moving_dbfs_file_to_local_dir(config, tmp_path, mocker):
         return_value=FileInfo(path="a", is_dir=False, file_size=4),
     )

-    def fake_read(path: str, *, length: int = None, offset: int = None):
+    def fake_read(path: str, *, length: int = None, offset: int = None):  # type: ignore[assignment]
         assert path == "a"
         assert length == 1048576
         if not offset:
diff --git a/tests/test_dbutils.py b/tests/test_dbutils.py
index 1b9a97f14..75d83c5fd 100644
--- a/tests/test_dbutils.py
+++ b/tests/test_dbutils.py
@@ -1,4 +1,4 @@
-import pytest as pytest
+import pytest as pytest  # type: ignore[import-not-found]

 from databricks.sdk.dbutils import FileInfo as DBUtilsFileInfo
 from databricks.sdk.service.files import FileInfo, ReadResponse
diff --git a/tests/test_errors.py b/tests/test_errors.py
index 57e045c3a..1a1fc5e1a 100644
--- a/tests/test_errors.py
+++ b/tests/test_errors.py
@@ -3,8 +3,8 @@
 from dataclasses import dataclass, field
 from typing import Any, List, Optional

-import pytest
-import requests
+import pytest  # type: ignore[import-not-found]
+import requests  # type: ignore[import-untyped]

 from databricks.sdk import errors
 from databricks.sdk.errors import details
@@ -47,7 +47,7 @@ def fake_valid_response(
     if error_code:
         body["error_code"] = error_code
     if len(details) > 0:
-        body["details"] = details
+        body["details"] = details  # type: ignore[assignment]

     return fake_response(method, status_code, json.dumps(body), path)

@@ -366,7 +366,7 @@ def test_get_api_error(test_case: TestCase):
     parser = errors._Parser()

     with pytest.raises(errors.DatabricksError) as e:
-        raise parser.get_api_error(test_case.response)
+        raise parser.get_api_error(test_case.response)  # type: ignore[misc]

     assert isinstance(e.value, test_case.want_err_type)
     assert str(e.value) == test_case.want_message
diff --git a/tests/test_fieldmask.py b/tests/test_fieldmask.py
index 3505bd90c..31953e3ce 100644
--- a/tests/test_fieldmask.py
+++ b/tests/test_fieldmask.py
@@ -1,4 +1,4 @@
-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.common.types.fieldmask import FieldMask

diff --git a/tests/test_files.py b/tests/test_files.py
index 64947e9b8..4849a6eec 100644
--- a/tests/test_files.py
+++ b/tests/test_files.py
@@ -16,10 +16,10 @@
 from typing import Any, Callable, Dict, List, Optional, Type, Union
 from urllib.parse import parse_qs, urlparse

-import pytest
-import requests
-import requests_mock
-from requests import RequestException
+import pytest  # type: ignore[import-not-found]
+import requests  # type: ignore[import-untyped]
+import requests_mock  # type: ignore[import-not-found]
+from requests import RequestException  # type: ignore[import-untyped]

 from databricks.sdk import WorkspaceClient
 from databricks.sdk.core import Config
@@ -113,7 +113,7 @@ def generate_response(
             # if server actually processed the request (and so changed its state)
             raise self.exception

-        custom_response = [self.code, self.body or "", {}]
+        custom_response = [self.code, self.body or "", {}]  # type: ignore[var-annotated]

         if activate_for_current_invocation:
             if self.code and 400 <= self.code < 500:
@@ -133,18 +133,18 @@ def generate_response(

         resp = requests.Response()

-        resp.request = request
-        resp.status_code = code
+        resp.request = request  # type: ignore[assignment]
+        resp.status_code = code  # type: ignore[assignment]
         if stream:
             if type(body_or_stream) != bytes:
-                resp.raw = io.BytesIO(body_or_stream.encode())
+                resp.raw = io.BytesIO(body_or_stream.encode())  # type: ignore[union-attr]
             else:
-                resp.raw = io.BytesIO(body_or_stream)
+                resp.raw = io.BytesIO(body_or_stream)  # type: ignore[arg-type]
         else:
-            resp._content = body_or_stream.encode()
+            resp._content = body_or_stream.encode()  # type: ignore[union-attr]

-        for key in headers:
-            resp.headers[key] = headers[key]
+        for key in headers:  # type: ignore[union-attr]
+            resp.headers[key] = headers[key]  # type: ignore[index]

         return resp
@@ -205,8 +205,8 @@ def run(self, config: Config, monkeypatch) -> None:
         config = config.copy()
         config.disable_experimental_files_api_client = not self.enable_new_client
         config.files_ext_client_download_max_total_recovers = self.max_recovers_total
-        config.files_ext_client_download_max_total_recovers_without_progressing = self.max_recovers_without_progressing
-        config.enable_presigned_download_api = False
+        config.files_ext_client_download_max_total_recovers_without_progressing = self.max_recovers_without_progressing  # type: ignore[assignment]
+        config.enable_presigned_download_api = False  # type: ignore[attr-defined]

         w = WorkspaceClient(config=config)
@@ -216,13 +216,13 @@ def run(self, config: Config, monkeypatch) -> None:
         if self.download_mode == DownloadMode.STREAM:
             if self.expected_exception is None:
                 response = w.files.download("/test").contents
-                actual_content = response.read()
+                actual_content = response.read()  # type: ignore[union-attr]
                 assert len(actual_content) == len(session.content)
                 assert actual_content == session.content
             else:
                 with pytest.raises(self.expected_exception):
                     response = w.files.download("/test").contents
-                    response.read()
+                    response.read()  # type: ignore[union-attr]
         elif self.download_mode == DownloadMode.FILE:  # FILE mode
             with NamedTemporaryFile(delete=False) as temp_file:
                 file_path = temp_file.name
@@ -291,7 +291,7 @@ def request(
         allow_redirects: bool = True,
         proxies=None,
         hooks=None,
-        stream: bool = None,
+        stream: bool = None,  # type: ignore[assignment]
         verify=None,
         cert=None,
         json=None,
@@ -309,7 +309,7 @@ def _handle_head_file(self, headers: Dict[str, str], url: str) -> "MockFilesApiD
         if "If-Unmodified-Since" in headers:
             assert headers["If-Unmodified-Since"] == self.last_modified
         resp = MockFilesApiDownloadResponse(self, 0, None, MockFilesApiDownloadRequest(url))
-        resp.content = ""
+        resp.content = ""  # type: ignore[attr-defined]
         return resp

     def _handle_get_file(self, headers: Dict[str, str], url: str) -> "MockFilesApiDownloadResponse":
@@ -345,7 +345,7 @@ class MockFilesApiDownloadRequest:

     def __init__(self, url: str):
         self.url = url
         self.method = "GET"
-        self.headers = dict()
+        self.headers = dict()  # type: ignore[var-annotated]
         self.body = None

@@ -368,8 +368,8 @@ def __init__(
         self.headers["Content-Length"] = (
             len(session.content) if end_byte_offset is None else end_byte_offset + 1
         ) - offset
-        self.headers["Content-Type"] = "application/octet-stream"
-        self.headers["Last-Modified"] = session.last_modified
+        self.headers["Content-Type"] = "application/octet-stream"  # type: ignore[assignment]
+        self.headers["Last-Modified"] = session.last_modified  # type: ignore[assignment]

         self.ok = True
         self.url = request.url
@@ -765,7 +765,7 @@ def get_header(self, request: requests.Request) -> requests.Response:
         resp = requests.Response()
         resp.status_code = 200
         resp._content = b""
-        resp.request = request
+        resp.request = request  # type: ignore[assignment]
         resp.headers["Content-Length"] = str(self.file_size)
         resp.headers["Content-Type"] = "application/octet-stream"
         resp.headers["Last-Modified"] = self.last_modified
@@ -917,13 +917,13 @@ def match_request_to_response(
             and request_url.path == "/api/2.0/fs/create-download-url"
         ):
             assert "path" in request_query, "Expected 'path' in query parameters"
-            file_path = request_query.get("path")[0]
+            file_path = request_query.get("path")[0]  # type: ignore[index]

             def processor() -> list:
                 url = server_state.get_presigned_url(file_path)
                 return [200, json.dumps({"url": url, "headers": {}}), {}]

-            return self.custom_response_create_presigned_url.generate_response(request, processor)
+            return self.custom_response_create_presigned_url.generate_response(request, processor)  # type: ignore[union-attr]

         # Get files status request
         elif (
@@ -936,7 +936,7 @@ def processor() -> list:
                 resp = server_state.get_header(request)
                 return [resp.status_code, resp._content, resp.headers]

-            return self.custom_response_get_file_status_api.generate_response(request, processor, stream=True)
+            return self.custom_response_get_file_status_api.generate_response(request, processor, stream=True)  # type: ignore[union-attr]

         # Direct Files API download request
         elif (
@@ -949,7 +949,7 @@ def processor() -> list:
                 resp = server_state.get_content(request, api_used="files_api")
                 return [resp.status_code, resp._content, resp.headers]

-            return self.custom_response_download_from_files_api.generate_response(request, processor, stream=True)
+            return self.custom_response_download_from_files_api.generate_response(request, processor, stream=True)  # type: ignore[union-attr]

         # Download from Presigned URL request
         elif request_url.hostname == PresignedUrlDownloadServerState.HOSTNAME and request.method == "GET":
@@ -959,7 +959,7 @@ def processor() -> list:
                 resp = server_state.get_content(request, api_used="presigned_url")
                 return [resp.status_code, resp._content, resp.headers]

-            return self.custom_response_download_from_url.generate_response(request, processor, stream=True)
+            return self.custom_response_download_from_url.generate_response(request, processor, stream=True)  # type: ignore[union-attr]

         else:
             raise RuntimeError("Unexpected request " + str(request))
@@ -969,12 +969,12 @@ def run_one_case(self, config: Config, monkeypatch, download_mode: DownloadMode,
             logger.debug("Parallel download is not supported on Windows. Falling back to sequential download.")
             return
         config = config.copy()
-        config.enable_presigned_download_api = True
+        config.enable_presigned_download_api = True  # type: ignore[attr-defined]
         config._clock = FakeClock()
         if self.parallel_download_min_file_size is not None:
             config.files_ext_parallel_download_min_file_size = self.parallel_download_min_file_size
         if self.parallel_upload_part_size is not None:
-            config.files_ext_parallel_upload_part_size = self.parallel_upload_part_size
+            config.files_ext_parallel_upload_part_size = self.parallel_upload_part_size  # type: ignore[attr-defined]

         w = WorkspaceClient(config=config)
         state = PresignedUrlDownloadServerState(self.file_size, self.last_modified)
@@ -994,7 +994,7 @@ def custom_matcher(request: requests.Request) -> Optional[requests.Response]:
             else:
                 download_resp = w.files.download(PresignedUrlDownloadTestCase._FILE_PATH)
                 assert download_resp.content_length == self.file_size
-                assert download_resp.contents.read() == state.content
+                assert download_resp.contents.read() == state.content  # type: ignore[union-attr]
                 if self.expected_download_api is not None:
                     assert state.api_used == self.expected_download_api
         elif download_mode == DownloadMode.FILE:
@@ -1157,7 +1157,7 @@ def __repr__(self) -> str:

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, FileContent):
-            return NotImplemented
+            return NotImplemented  # type: ignore[return-value]

         return self._length == other._length and self.checksum == other.checksum

@@ -1168,8 +1168,8 @@ class MultipartUploadServerState:
     abort_upload_url_prefix = "https://cloud_provider.com/abort-upload/"

     def __init__(self, expected_part_size: Optional[int] = None):
-        self.issued_multipart_urls = {}  # part_number -> expiration_time
-        self.uploaded_parts = {}  # part_number -> [part file path, etag]
+        self.issued_multipart_urls = {}  # type: ignore[var-annotated] # part_number -> expiration_time
+        self.uploaded_parts = {}  # type: ignore[var-annotated] # part_number -> [part file path, etag]
         self.session_token = "token-" + MultipartUploadServerState.randomstr()
         self.file_content = None
         self.issued_abort_url_expire_time = None
@@ -1185,7 +1185,7 @@ def create_upload_part_url(self, path: str, part_number: int, expire_time: datet

     def create_abort_url(self, path: str, expire_time: datetime) -> str:
         assert not self.aborted
-        self.issued_abort_url_expire_time = expire_time
+        self.issued_abort_url_expire_time = expire_time  # type: ignore[assignment]
         return f"{self.abort_upload_url_prefix}{path}"

     def save_part(self, part_number: int, part_content: bytes, etag: str) -> None:
@@ -1241,7 +1241,7 @@ def upload_complete(self, etags: dict) -> None:
                 part_content = f.read()
                 sha256.update(part_content)

-        self.file_content = FileContent(size, sha256.hexdigest())
+        self.file_content = FileContent(size, sha256.hexdigest())  # type: ignore[assignment]

     def abort_upload(self) -> None:
         self.aborted = True
@@ -1307,7 +1307,7 @@ def __init__(
         self.expected_single_shot_upload = expected_single_shot_upload

         self.path = "/test.txt"
-        self.created_temp_files = []
+        self.created_temp_files = []  # type: ignore[var-annotated]

     def customize_config(self, config: Config) -> None:
         pass
@@ -1362,14 +1362,14 @@ def run_one_case(self, config: Config, use_parallel: bool, source_type: "UploadS
         if self.sdk_retry_timeout_seconds:
             config.retry_timeout_seconds = self.sdk_retry_timeout_seconds
         if self.multipart_upload_part_size:
-            config.multipart_upload_part_size = self.multipart_upload_part_size
+            config.multipart_upload_part_size = self.multipart_upload_part_size  # type: ignore[attr-defined]
         if self.multipart_upload_max_retries:
             config.files_ext_multipart_upload_max_retries = self.multipart_upload_max_retries
         config.files_ext_multipart_upload_min_stream_size = self.multipart_upload_min_stream_size

         pat_token = "some_pat_token"
-        config._header_factory = lambda: {"Authorization": f"Bearer {pat_token}"}
+        config._header_factory = lambda: {"Authorization": f"Bearer {pat_token}"}  # type: ignore[assignment]

         self.customize_config(config)
@@ -1394,7 +1394,7 @@ def custom_matcher(request: requests.Request) -> Optional[requests.Response]:
             ):

                 def processor() -> list:
-                    body = request.body.read()
+                    body = request.body.read()  # type: ignore[attr-defined]
                     single_shot_server_state.upload(body)
                     return [200, "", {}]
@@ -1409,7 +1409,7 @@ def upload() -> None:
             if source_type == UploadSourceType.FILE:
                 w.files.upload_from(
                     file_path=self.path,
-                    source_path=content_or_source,
+                    source_path=content_or_source,  # type: ignore[arg-type]
                     overwrite=self.overwrite,
                     part_size=self.multipart_upload_part_size,
                     use_parallel=use_parallel,
@@ -1418,7 +1418,7 @@ def upload() -> None:
             else:
                 w.files.upload(
                     file_path=self.path,
-                    contents=content_or_source,
+                    contents=content_or_source,  # type: ignore[arg-type]
                     overwrite=self.overwrite,
                     part_size=self.multipart_upload_part_size,
                     use_parallel=use_parallel,
@@ -1625,7 +1625,7 @@ def clear_state(self) -> None:
         self.custom_response_on_abort.clear_state()

     def match_request_to_response(
-        self, request: requests.Request, server_state: MultipartUploadServerState
+        self, request: requests.Request, server_state: MultipartUploadServerState  # type: ignore[override]
     ) -> Optional[requests.Response]:
         request_url = urlparse(request.url)
         request_query = parse_qs(request_url.query)
@@ -1639,7 +1639,7 @@ def match_request_to_response(
         ):
             assert UploadTestCase.is_auth_header_present(request)
-            assert request.text is None
+            assert request.text is None  # type: ignore[attr-defined]

             def processor() -> list:
                 response_json = {"multipart_upload": {"session_token": server_state.session_token}}
@@ -1694,7 +1694,7 @@ def processor() -> list:
             assert url_path[: -len(part_num) - 1] == self.path

             def processor() -> list:
-                body = request.body.read()
+                body = request.body.read()  # type: ignore[attr-defined]
                 etag = "etag-" + MultipartUploadServerState.randomstr()
                 server_state.save_part(int(part_num), body, etag)
                 return [200, "", {"ETag": etag}]
@@ -1762,9 +1762,9 @@ def processor() -> list:
             and request.method == "PUT"
         ):
             assert MultipartUploadTestCase.is_auth_header_present(request)
-            assert request.content is not None
+            assert request.content is not None  # type: ignore[attr-defined]

-            def processor():
+            def processor():  # type: ignore[misc]
                 server_state.file_content = FileContent.from_bytes(request.content)
                 return [200, "", {}]
@@ -2269,7 +2269,7 @@ class ResumableUploadServerState:
     def __init__(self, unconfirmed_delta: Union[int, list], expected_part_size: Optional[int]):
         self.unconfirmed_delta = unconfirmed_delta
         self.confirmed_last_byte: Optional[int] = None  # inclusive
-        self.uploaded_parts = []
+        self.uploaded_parts = []  # type: ignore[var-annotated]
         self.session_token = "token-" + MultipartUploadServerState.randomstr()
         self.file_content: Optional[FileContent] = None
         self.aborted = False
@@ -2372,7 +2372,7 @@ def __init__(
         self,
         name: str,
         stream_size: int,
-        cloud: Cloud = None,
+        cloud: Cloud = None,  # type: ignore[assignment]
         overwrite: bool = True,
         source_type: Optional[List[UploadSourceType]] = None,
         use_parallel: Optional[List[bool]] = None,
@@ -2439,7 +2439,7 @@ def clear_state(self) -> None:
         self.custom_response_on_abort.clear_state()

     def match_request_to_response(
-        self, request: requests.Request, server_state: ResumableUploadServerState
+        self, request: requests.Request, server_state: ResumableUploadServerState  # type: ignore[override]
     ) -> Optional[requests.Response]:
         request_url = urlparse(request.url)
         request_query = parse_qs(request_url.query)
@@ -2453,7 +2453,7 @@ def match_request_to_response(
         ):
             assert UploadTestCase.is_auth_header_present(request)
-            assert request.text is None
+            assert request.text is None  # type: ignore[attr-defined]

             def processor() -> list:
                 response_json = {"resumable_upload": {"session_token": server_state.session_token}}
@@ -2499,17 +2499,17 @@ def processor() -> list:
             content_range_header = request.headers["Content-range"]
             is_status_check_request = re.match("bytes \\*/\\*", content_range_header)
             if is_status_check_request:
-                assert not request.body
+                assert not request.body  # type: ignore[attr-defined]
                 response_customizer = self.custom_response_on_status_check
             else:
                 response_customizer = self.custom_response_on_upload

             def processor() -> list:
                 if not is_status_check_request:
-                    body = request.body.read()
+                    body = request.body.read()  # type: ignore[attr-defined]

                     match = re.match("bytes (\\d+)-(\\d+)/(.+)", content_range_header)
-                    [range_start_s, range_end_s, file_size_s] = match.groups()
+                    [range_start_s, range_end_s, file_size_s] = match.groups()  # type: ignore[union-attr]

                     server_state.save_part(int(range_start_s), int(range_end_s), body, file_size_s)
diff --git a/tests/test_files_utils.py b/tests/test_files_utils.py
index d11d91dbc..97deb3b6c 100644
--- a/tests/test_files_utils.py
+++ b/tests/test_files_utils.py
@@ -4,7 +4,7 @@
 from io import BytesIO, RawIOBase, UnsupportedOperation
 from typing import BinaryIO, Callable, List, Optional, Tuple

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.mixins.files_utils import (_ConcatenatedInputStream,
                                                _PresignedUrlDistributor)
@@ -35,10 +35,10 @@ def parse_range_header(range_header: str, content_length: Optional[int] = None)
     if end is not None and start > end:
         raise ValueError(f"Start byte {start} is greater than end byte {end}")

-    return start, end
+    return start, end  # type: ignore[return-value]


-class NonSeekableBuffer(RawIOBase, BinaryIO):
+class NonSeekableBuffer(RawIOBase, BinaryIO):  # type: ignore[misc]
     """
     A non-seekable buffer that wraps a bytes object. Used for unit tests only.
     This class implements the BinaryIO interface but does not support seeking.
@@ -51,7 +51,7 @@ def __init__(self, data: bytes):
     def read(self, size: int = -1) -> bytes:
         return self._stream.read(size)

-    def readline(self, size: int = -1) -> bytes:
+    def readline(self, size: int = -1) -> bytes:  # type: ignore[override]
         return self._stream.readline(size)

     def readlines(self, size: int = -1) -> List[bytes]:
@@ -77,7 +77,7 @@ def generate(self) -> Tuple[bytes, BinaryIO]:
         pass


-class ConcatenatedInputStreamTestCase(ConcatenatedInputStreamTestCase):
+class ConcatenatedInputStreamTestCase(ConcatenatedInputStreamTestCase):  # type: ignore[no-redef]
     def __init__(self, head: bytes, tail: bytes, is_seekable: bool = True):
         self._head = head
         self._tail = tail
@@ -108,34 +108,34 @@ def to_string(test_case) -> str:


 test_cases = [
-    ConcatenatedInputStreamTestCase(b"", b"zzzz"),
-    ConcatenatedInputStreamTestCase(b"", b""),
-    ConcatenatedInputStreamTestCase(b"", b"", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"foo", b"bar"),
-    ConcatenatedInputStreamTestCase(b"foo", b"bar", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"", b"zzzz", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"non_empty", b""),
-    ConcatenatedInputStreamTestCase(b"non_empty", b"", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"\n\n\n", b"\n\n"),
-    ConcatenatedInputStreamTestCase(b"\n\n\n", b"\n\n", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"aa\nbb\nccc\n", b"dd\nee\nff"),
-    ConcatenatedInputStreamTestCase(b"aa\nbb\nccc\n", b"dd\nee\nff", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"First line\nsecond line", b"first line with line \nbreak"),
-    ConcatenatedInputStreamTestCase(b"First line\nsecond line", b"first line with line \nbreak", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"First line\n", b"\nsecond line"),
-    ConcatenatedInputStreamTestCase(b"First line\n", b"\nsecond line", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"First line\n", b"\n"),
-    ConcatenatedInputStreamTestCase(b"First line\n", b"\n", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"First line\n", b""),
-    ConcatenatedInputStreamTestCase(b"First line\n", b"", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"", b"\nA line"),
-    ConcatenatedInputStreamTestCase(b"", b"\nA line", is_seekable=False),
-    ConcatenatedInputStreamTestCase(b"\n", b"\nA line"),
-    ConcatenatedInputStreamTestCase(b"\n", b"\nA line", is_seekable=False),
+    ConcatenatedInputStreamTestCase(b"", b"zzzz"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"", b""),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"", b"", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"foo", b"bar"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"foo", b"bar", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"", b"zzzz", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"non_empty", b""),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"non_empty", b"", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"\n\n\n", b"\n\n"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"\n\n\n", b"\n\n", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"aa\nbb\nccc\n", b"dd\nee\nff"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"aa\nbb\nccc\n", b"dd\nee\nff", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\nsecond line", b"first line with line \nbreak"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\nsecond line", b"first line with line \nbreak", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\nsecond line"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\nsecond line", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\n"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\n", b"\n", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\n", b""),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"First line\n", b"", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"", b"\nA line"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"", b"\nA line", is_seekable=False),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"\n", b"\nA line"),  # type: ignore[abstract, call-arg]
+    ConcatenatedInputStreamTestCase(b"\n", b"\nA line", is_seekable=False),  # type: ignore[abstract, call-arg]
 ]


-def verify(test_case: ConcatenatedInputStreamTestCase, apply: Callable[[BinaryIO], Tuple[any, bool]]):
+def verify(test_case: ConcatenatedInputStreamTestCase, apply: Callable[[BinaryIO], Tuple[any, bool]]):  # type: ignore[valid-type]
     """
     This method applies given function iteratively to both implementation under test and reference
     implementation of the stream, and verifies the result on each step is identical.
@@ -167,7 +167,7 @@ def verify_eof(buffer: BinaryIO):
     assert len(buffer.readlines(100)) == 0


-@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)  # type: ignore[attr-defined]
 @pytest.mark.parametrize("limit", [-1, 0, 1, 3, 4, 5, 6, 10, 100, 1000])
 def test_read(config, test_case: ConcatenatedInputStreamTestCase, limit: int):
     def apply(buffer: BinaryIO):
@@ -182,7 +182,7 @@ def apply(buffer: BinaryIO):
     verify(test_case, apply)


-@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)  # type: ignore[attr-defined]
 @pytest.mark.parametrize("limit", [-1, 0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 100, 1000])
 def test_read_line(config, test_case: ConcatenatedInputStreamTestCase, limit: int):
     def apply(buffer: BinaryIO):
@@ -193,7 +193,7 @@ def apply(buffer: BinaryIO):
     verify(test_case, apply)


-@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)  # type: ignore[attr-defined]
 @pytest.mark.parametrize("limit", [-1, 0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 100, 1000])
 def test_read_lines(config, test_case: ConcatenatedInputStreamTestCase, limit: int):
     def apply(buffer: BinaryIO):
@@ -204,7 +204,7 @@ def apply(buffer: BinaryIO):
     verify(test_case, apply)


-@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)  # type: ignore[attr-defined]
 def test_iterator(config, test_case: ConcatenatedInputStreamTestCase):
     def apply(buffer: BinaryIO):
         try:
@@ -216,11 +216,11 @@ def apply(buffer: BinaryIO):
     verify(test_case, apply)


-def seeks_to_string(seeks: [Tuple[int, int]]):
+def seeks_to_string(seeks: [Tuple[int, int]]):  # type: ignore[valid-type]
     ", ".join(list(map(lambda seek: f"Seek: offset={seek[0]}, whence={seek[1]}", seeks)))


-@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)
+@pytest.mark.parametrize("test_case", test_cases, ids=ConcatenatedInputStreamTestCase.to_string)  # type: ignore[attr-defined]
 @pytest.mark.parametrize(
     "seeks",
     [
@@ -245,7 +245,7 @@ def read_and_restore(buf: BinaryIO) -> bytes:
         buf.seek(pos)
         return result

-    def safe_call(buf: BinaryIO, call: Callable[[BinaryIO], any]) -> (any, bool):
+    def safe_call(buf: BinaryIO, call: Callable[[BinaryIO], any]) -> (any, bool):  # type: ignore[syntax, valid-type]
         """
         Calls the provided function on the buffer and returns the result.
         It is a wrapper to handle exceptions gracefully.
diff --git a/tests/test_internal.py b/tests/test_internal.py
index b0417ec64..83a7ad56d 100644
--- a/tests/test_internal.py
+++ b/tests/test_internal.py
@@ -1,9 +1,9 @@
 from dataclasses import dataclass
 from enum import Enum

-import pytest
-from google.protobuf.duration_pb2 import Duration
-from google.protobuf.timestamp_pb2 import Timestamp
+import pytest  # type: ignore[import-not-found]
+from google.protobuf.duration_pb2 import Duration  # type: ignore[import-untyped]
+from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore[import-untyped]

 from databricks.sdk.common.types.fieldmask import FieldMask
 from databricks.sdk.service._internal import (
diff --git a/tests/test_metadata_service_auth.py b/tests/test_metadata_service_auth.py
index e293a0b4b..1c7fa110b 100644
--- a/tests/test_metadata_service_auth.py
+++ b/tests/test_metadata_service_auth.py
@@ -1,7 +1,7 @@
 import json
 from datetime import datetime, timedelta

-import requests
+import requests  # type: ignore[import-untyped]

 from databricks.sdk.core import Config
 from databricks.sdk.credentials_provider import MetadataServiceTokenSource
diff --git a/tests/test_model_serving_auth.py b/tests/test_model_serving_auth.py
index 3c3ddfa99..60cee2e1b 100644
--- a/tests/test_model_serving_auth.py
+++ b/tests/test_model_serving_auth.py
@@ -1,7 +1,7 @@
 import threading
 import time

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.core import Config
 from databricks.sdk.credentials_provider import ModelServingUserCredentials
diff --git a/tests/test_oidc.py b/tests/test_oidc.py
index 4bed32e96..6e2e2046a 100644
--- a/tests/test_oidc.py
+++ b/tests/test_oidc.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk import oidc

@@ -11,8 +11,8 @@ class EnvTestCase:
     name: str
     env_name: str = ""
     env_value: str = ""
-    want: oidc.IdToken = None
-    wantException: Exception = None
+    want: oidc.IdToken = None  # type: ignore[assignment]
+    wantException: Exception = None  # type: ignore[assignment]


 _env_id_test_cases = [
@@ -26,13 +26,13 @@ class EnvTestCase:
         name="missing_env_var",
         env_name="OIDC_TEST_TOKEN_MISSING",
         env_value="",
-        wantException=ValueError,
+        wantException=ValueError,  # type: ignore[arg-type]
     ),
     EnvTestCase(
         name="empty_env_var",
         env_name="OIDC_TEST_TOKEN_EMPTY",
         env_value="",
-        wantException=ValueError,
+        wantException=ValueError,  # type: ignore[arg-type]
     ),
     EnvTestCase(
         name="different_variable_name",
@@ -60,8 +60,8 @@ class FileTestCase:
     name: str
     file: Optional[Tuple[str, str]] = None  # (name, content)
     filepath: str = ""
-    want: oidc.IdToken = None
-    wantException: Exception = None
+    want: oidc.IdToken = None  # type: ignore[assignment]
+    wantException: Exception = None  # type: ignore[assignment]


 _file_id_test_cases = [
@@ -69,18 +69,18 @@ class FileTestCase:
         name="missing_filepath",
         file=("token", "content"),
         filepath="",
-        wantException=ValueError,
+        wantException=ValueError,  # type: ignore[arg-type]
     ),
     FileTestCase(
         name="empty_file",
         file=("token", ""),
         filepath="token",
-        wantException=ValueError,
+        wantException=ValueError,  # type: ignore[arg-type]
     ),
     FileTestCase(
         name="file_does_not_exist",
         filepath="nonexistent-file",
-        wantException=ValueError,
+        wantException=ValueError,  # type: ignore[arg-type]
     ),
     FileTestCase(
         name="file_exists",
diff --git a/tests/test_oidc_token_supplier.py b/tests/test_oidc_token_supplier.py
index 57109c37b..62666325d 100644
--- a/tests/test_oidc_token_supplier.py
+++ b/tests/test_oidc_token_supplier.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Dict, Optional

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.oidc_token_supplier import AzureDevOpsOIDCTokenSupplier
diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py
index dfc248d0a..27844fbb3 100644
--- a/tests/test_open_ai_mixin.py
+++ b/tests/test_open_ai_mixin.py
@@ -1,7 +1,7 @@
 import sys
 from io import BytesIO

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.core import Config
 from databricks.sdk.service.serving import ExternalFunctionRequestHttpMethod
diff --git a/tests/test_refreshable.py b/tests/test_refreshable.py
index dc3157331..485d46737 100644
--- a/tests/test_refreshable.py
+++ b/tests/test_refreshable.py
@@ -13,14 +13,14 @@ def __init__(
         disable_async,
         token=None,
         stale_duration=timedelta(seconds=60),
-        refresh_effect: Callable[[], Token] = None,
+        refresh_effect: Callable[[], Token] = None,  # type: ignore[assignment]
     ):
         super().__init__(token, disable_async, stale_duration)
         self._refresh_effect = refresh_effect
         self._refresh_count = 0

     def refresh(self) -> Token:
-        if self._refresh_effect:
+        if self._refresh_effect:  # type: ignore[truthy-function]
             self._token = self._refresh_effect()
         self._refresh_count += 1
         return self._token
@@ -41,7 +41,7 @@ def f() -> Token:

 def blocking_refresh(
     token: Token,
-) -> (Callable[[], Token], Callable[[], None]):
+) -> (Callable[[], Token], Callable[[], None]):  # type: ignore[syntax]
     """
     Create a refresh function that blocks until unblock is called.
diff --git a/tests/test_retries.py b/tests/test_retries.py
index 3fc97114d..0583287d4 100644
--- a/tests/test_retries.py
+++ b/tests/test_retries.py
@@ -1,7 +1,7 @@
 from datetime import timedelta
 from typing import Any, Literal, Optional, Tuple, Type

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.errors import NotFound, ResourceDoesNotExist
 from databricks.sdk.retries import RetryError, poll, retried
@@ -257,7 +257,7 @@ def fn() -> Tuple[Any, Optional[RetryError]]:
         call_count += 1

         if scenario == "success":
-            if call_count < attempts:
+            if call_count < attempts:  # type: ignore[operator]
                 return None, RetryError.continues(f"attempt {call_count}")
             return result_value, None

@@ -265,12 +265,12 @@ def fn() -> Tuple[Any, Optional[RetryError]]:
             return None, RetryError.continues("retrying")

         elif scenario == "halt":
-            if call_count < attempts:
+            if call_count < attempts:  # type: ignore[operator]
                 return None, RetryError.continues("retrying")
             return None, RetryError.halt(ValueError(exception_msg))

         elif scenario == "unexpected":
-            if call_count < attempts:
+            if call_count < attempts:  # type: ignore[operator]
                 return None, RetryError.continues("retrying")
             raise RuntimeError(exception_msg)

@@ -286,10 +286,10 @@ def fn() -> Tuple[Any, Optional[RetryError]]:
         with pytest.raises(exception_type) as exc_info:
             poll(fn, timeout=timedelta(seconds=timeout), clock=clock)

-        assert exception_msg in str(exc_info.value)
+        assert exception_msg in str(exc_info.value)  # type: ignore[operator]
         assert call_count >= 1
         if scenario == "timeout":
-            assert clock.time() >= min_time
+            assert clock.time() >= min_time  # type: ignore[operator]
         elif scenario in ("halt", "unexpected"):
             assert call_count == attempts
diff --git a/tests/test_user_agent.py b/tests/test_user_agent.py
index c9c6889c4..1fb70da37 100644
--- a/tests/test_user_agent.py
+++ b/tests/test_user_agent.py
@@ -1,6 +1,6 @@
 import os

-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.version import __version__
diff --git a/tests/testdata/test_casing.py b/tests/testdata/test_casing.py
index ef6257034..f2d2ec68d 100644
--- a/tests/testdata/test_casing.py
+++ b/tests/testdata/test_casing.py
@@ -1,4 +1,4 @@
-import pytest
+import pytest  # type: ignore[import-not-found]

 from databricks.sdk.casing import Casing