diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 141e7e44d..27607da1a 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -5,8 +5,7 @@ import databricks.sdk.core as client import databricks.sdk.dbutils as dbutils -import databricks.sdk.service as service -from databricks.sdk import azure +from databricks.sdk import azure, service from databricks.sdk.credentials_provider import CredentialsStrategy from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.files import DbfsExt, FilesExt diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 90570de97..5af684cb6 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -43,17 +43,17 @@ class _BaseClient: def __init__( self, - debug_truncate_bytes: int = None, - retry_timeout_seconds: int = None, - user_agent_base: str = None, - header_factory: Callable[[], dict] = None, - max_connection_pools: int = None, - max_connections_per_pool: int = None, - pool_block: bool = True, - http_timeout_seconds: float = None, - extra_error_customizers: List[_ErrorCustomizer] = None, - debug_headers: bool = False, - clock: Clock = None, + debug_truncate_bytes: Optional[int] = None, + retry_timeout_seconds: Optional[int] = None, + user_agent_base: Optional[str] = None, + header_factory: Optional[Callable[[], dict]] = None, + max_connection_pools: Optional[int] = None, + max_connections_per_pool: Optional[int] = None, + pool_block: Optional[bool] = True, + http_timeout_seconds: Optional[float] = None, + extra_error_customizers: Optional[List[_ErrorCustomizer]] = None, + debug_headers: Optional[bool] = False, + clock: Optional[Clock] = None, streaming_buffer_size: int = 1024 * 1024, ): # 1MB """ @@ -148,14 +148,14 @@ def do( self, method: str, url: str, - query: dict = None, - headers: dict = None, - body: dict = None, + query: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, raw: bool = False, files=None, data=None, - auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, - response_headers: List[str] = None, + auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None, + response_headers: Optional[List[str]] = None, ) -> Union[dict, list, BinaryIO]: if headers is None: headers = {} @@ -272,9 +272,9 @@ def _perform( self, method: str, url: str, - query: dict = None, - headers: dict = None, - body: dict = None, + query: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, raw: bool = False, files=None, data=None, @@ -325,10 +325,10 @@ class _StreamingResponse(BinaryIO): _closed: bool = False def fileno(self) -> int: - pass + return 0 - def flush(self) -> int: - pass + def flush(self) -> int: # type: ignore + return 0 def __init__(self, response: _RawResponse, chunk_size: Union[int, None] = None): self._response = response @@ -403,10 +403,10 @@ def truncate(self, __size: Union[int, None] = ...) -> int: def writable(self) -> bool: return False - def write(self, s: Union[bytes, bytearray]) -> int: + def write(self, s: Union[bytes, bytearray]) -> int: # type: ignore raise NotImplementedError() - def writelines(self, lines: Iterable[bytes]) -> None: + def writelines(self, lines: Iterable[bytes]) -> None: # type: ignore raise NotImplementedError() def __next__(self) -> bytes: diff --git a/databricks/sdk/_widgets/ipywidgets_utils.py b/databricks/sdk/_widgets/ipywidgets_utils.py index 6e002562e..3caff486d 100644 --- a/databricks/sdk/_widgets/ipywidgets_utils.py +++ b/databricks/sdk/_widgets/ipywidgets_utils.py @@ -30,7 +30,7 @@ def value(self): if type(value) == list or type(value) == tuple: return ",".join(value) - raise ValueError("The returned value has invalid type (" + type(value) + ").") + raise ValueError(f"The returned value has invalid type ({type(value)}).") class IPyWidgetUtil(WidgetUtils): diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 203e84e6c..92e3dbf89 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -66,16 +66,16 @@ def get_oauth_token(self, auth_details: str) -> Token: def do( self, method: str, - path: str = None, - url: str = None, - query: dict = None, - headers: dict = None, - body: dict = None, + path: Optional[str] = None, + url: Optional[str] = None, + query: Optional[dict] = None, + headers: Optional[dict] = None, + body: Optional[dict] = None, raw: bool = False, files=None, data=None, - auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, - response_headers: List[str] = None, + auth: Optional[Callable[[requests.PreparedRequest], requests.PreparedRequest]] = None, + response_headers: Optional[List[str]] = None, ) -> Union[dict, list, BinaryIO]: if url is None: # Remove extra `/` from path for Files API diff --git a/databricks/sdk/environments.py b/databricks/sdk/environments.py index 170767a09..d14393329 100644 --- a/databricks/sdk/environments.py +++ b/databricks/sdk/environments.py @@ -113,7 +113,7 @@ def azure_active_directory_endpoint(self) -> Optional[str]: ] -def get_environment_for_hostname(hostname: str) -> DatabricksEnvironment: +def get_environment_for_hostname(hostname: Optional[str]) -> DatabricksEnvironment: if not hostname: return DEFAULT_ENVIRONMENT for env in ALL_ENVS: diff --git a/databricks/sdk/errors/base.py b/databricks/sdk/errors/base.py index 908172576..4e104c0b7 100644 --- a/databricks/sdk/errors/base.py +++ b/databricks/sdk/errors/base.py @@ -1,7 +1,7 @@ import re import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import requests @@ -12,10 +12,10 @@ class ErrorDetail: def __init__( self, - type: str = None, - reason: str = None, - domain: str = None, - metadata: dict = None, + type: Optional[str] = None, + reason: Optional[str] = None, + domain: Optional[str] = None, + metadata: Optional[dict] = None, **kwargs, ): self.type = type @@ -24,7 +24,7 @@ def __init__( self.metadata = metadata @classmethod - def from_dict(cls, d: Dict[str, any]) -> "ErrorDetail": + def from_dict(cls, d: Dict[str, Any]) -> "ErrorDetail": # Key "@type" is not a valid keyword argument name in Python. Rename # it to "type" to avoid conflicts. safe_args = {} @@ -39,15 +39,15 @@ class DatabricksError(IOError): def __init__( self, - message: str = None, + message: Optional[str] = None, *, - error_code: str = None, - detail: str = None, - status: str = None, - scimType: str = None, - error: str = None, - retry_after_secs: int = None, - details: List[Dict[str, any]] = None, + error_code: Optional[str] = None, + detail: Optional[str] = None, + status: Optional[str] = None, + scimType: Optional[str] = None, + error: Optional[str] = None, + retry_after_secs: Optional[int] = None, + details: Optional[List[Dict[str, Any]]] = None, **kwargs, ): """ @@ -102,7 +102,7 @@ def __init__( super().__init__(message if message else error) self.error_code = error_code self.retry_after_secs = retry_after_secs - self._error_details = errdetails.parse_error_details(details) + self._error_details = errdetails.parse_error_details(details or []) self.kwargs = kwargs # Deprecated. diff --git a/databricks/sdk/logger/round_trip_logger.py b/databricks/sdk/logger/round_trip_logger.py index e6ac5e80b..7ff9d55c9 100644 --- a/databricks/sdk/logger/round_trip_logger.py +++ b/databricks/sdk/logger/round_trip_logger.py @@ -1,6 +1,6 @@ import json import urllib.parse -from typing import Dict, List +from typing import Any, Dict, List import requests @@ -99,7 +99,7 @@ def _recursive_marshal_list(self, s, budget) -> list: budget -= len(str(raw)) return out - def _recursive_marshal(self, v: any, budget: int) -> any: + def _recursive_marshal(self, v: Any, budget: int) -> Any: if isinstance(v, dict): return self._recursive_marshal_dict(v, budget) elif isinstance(v, list): diff --git a/databricks/sdk/mixins/workspace.py b/databricks/sdk/mixins/workspace.py index 7476a4b83..f62ad5bff 100644 --- a/databricks/sdk/mixins/workspace.py +++ b/databricks/sdk/mixins/workspace.py @@ -1,11 +1,11 @@ -from typing import BinaryIO, Iterator, Optional, Union +from typing import Any, BinaryIO, Iterator, Optional, Union from ..core import DatabricksError from ..service.workspace import (ExportFormat, ImportFormat, Language, ObjectInfo, ObjectType, WorkspaceAPI) -def _fqcn(x: any) -> str: +def _fqcn(x: Any) -> str: return f"{x.__module__}.{x.__name__}" diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 4685efa5c..d2df2f0f5 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -78,9 +78,9 @@ def as_dict(self) -> dict: @dataclass class Token: access_token: str - token_type: str = None - refresh_token: str = None - expiry: datetime = None + token_type: Optional[str] = None + refresh_token: Optional[str] = None + expiry: Optional[datetime] = None @property def expired(self): @@ -238,7 +238,7 @@ def _get_executor(cls): def __init__( self, - token: Token = None, + token: Optional[Token] = None, disable_async: bool = True, stale_duration: timedelta = _DEFAULT_STALE_DURATION, ): @@ -248,7 +248,7 @@ def __init__( # Lock self._lock = threading.Lock() # Non Thread safe properties. They should be accessed only when protected by the lock above. - self._token = token + self._token = token or Token("") self._is_refreshing = False self._refresh_err = False @@ -312,7 +312,7 @@ def _trigger_async_refresh(self): """Starts an asynchronous refresh if none is in progress.""" def _refresh_internal(): - new_token: Token = None + new_token = None try: new_token = self.refresh() except Exception as e: @@ -737,9 +737,9 @@ def __init__( host: str, oidc_endpoints: OidcEndpoints, client_id: str, - redirect_url: str = None, - client_secret: str = None, - scopes: List[str] = None, + redirect_url: Optional[str] = None, + client_secret: Optional[str] = None, + scopes: Optional[List[str]] = None, ) -> None: self._host = host self._client_id = client_id diff --git a/databricks/sdk/retries.py b/databricks/sdk/retries.py index e4408929d..5528a8978 100644 --- a/databricks/sdk/retries.py +++ b/databricks/sdk/retries.py @@ -11,11 +11,11 @@ def retried( *, - on: Sequence[Type[BaseException]] = None, - is_retryable: Callable[[BaseException], Optional[str]] = None, + on: Optional[Sequence[Type[BaseException]]] = None, + is_retryable: Optional[Callable[[BaseException], Optional[str]]] = None, timeout=timedelta(minutes=20), - clock: Clock = None, - before_retry: Callable = None, + clock: Optional[Clock] = None, + before_retry: Optional[Callable] = None, ): has_allowlist = on is not None has_callback = is_retryable is not None diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py index c4bfd042b..adf26c707 100644 --- a/databricks/sdk/runtime/__init__.py +++ b/databricks/sdk/runtime/__init__.py @@ -132,7 +132,7 @@ def inner() -> Dict[str, str]: try: # We expect this to fail locally since dbconnect does not support sparkcontext. This is just for typing - sc = spark.sparkContext + sc = spark.sparkContext # type: ignore except Exception as e: logging.debug(f"Failed to initialize global 'sc', continuing. Cause: {e}") diff --git a/pyproject.toml b/pyproject.toml index 82c6be56a..72dab7d59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,4 +75,11 @@ include = ["databricks", "databricks.*"] [tool.black] line-length = 120 -target-version = ['py37', 'py38', 'py39', 'py310', 'py311','py312','py313'] \ No newline at end of file +target-version = ['py37', 'py38', 'py39', 'py310', 'py311','py312','py313'] + +[tool.pyright] +include = ["."] +exclude = ["**/node_modules", "**/__pycache__"] +reportMissingImports = true +reportMissingTypeStubs = false +pythonVersion = "3.7" \ No newline at end of file