Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ jobs:
- name: Fail on differences
run: git diff --exit-code

type-check:
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v2

- name: Run mypy type checking
run: make dev mypy

check-manifest:
runs-on: ubuntu-latest

Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ lint:
pycodestyle databricks
autoflake --check-diff --quiet --recursive databricks

mypy:
python -m mypy databricks tests

test:
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests

Expand Down
22 changes: 11 additions & 11 deletions databricks/sdk/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
Optional, Type, Union)

import requests
import requests.adapters
import requests # type: ignore[import-untyped]
import requests.adapters # type: ignore[import-untyped]

from . import useragent
from .casing import Casing
Expand Down Expand Up @@ -92,16 +92,16 @@ def __init__(
http_adapter = requests.adapters.HTTPAdapter(
pool_connections=max_connections_per_pool or 20,
pool_maxsize=max_connection_pools or 20,
pool_block=pool_block,
pool_block=pool_block, # type: ignore[arg-type]
)
self._session.mount("https://", http_adapter)

# Default to 60 seconds
self._http_timeout_seconds = http_timeout_seconds or 60

self._error_parser = _Parser(
extra_error_customizers=extra_error_customizers,
debug_headers=debug_headers,
extra_error_customizers=extra_error_customizers, # type: ignore[arg-type]
debug_headers=debug_headers, # type: ignore[arg-type]
)

def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
Expand All @@ -127,7 +127,7 @@ def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]:
# {'filter_by.user_ids': [123, 456]}
# See the following for more information:
# https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule
def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]:
def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: # type: ignore[misc]
for k1, v1 in d.items():
if isinstance(v1, dict):
v1 = dict(flatten_dict(v1))
Expand Down Expand Up @@ -281,7 +281,7 @@ def _perform(
raw: bool = False,
files=None,
data=None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None,
auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, # type: ignore[assignment]
):
response = self._session.request(
method,
Expand All @@ -305,7 +305,7 @@ def _perform(
def _record_request_log(self, response: requests.Response, raw: bool = False) -> None:
if not logger.isEnabledFor(logging.DEBUG):
return
logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate())
logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate()) # type: ignore[arg-type]


class _RawResponse(ABC):
Expand Down Expand Up @@ -343,7 +343,7 @@ def _open(self) -> None:
if self._closed:
raise ValueError("I/O operation on closed file")
if not self._content:
self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False)
self._content = self._response.iter_content(chunk_size=self._chunk_size, decode_unicode=False) # type: ignore[arg-type]

def __enter__(self) -> BinaryIO:
self._open()
Expand Down Expand Up @@ -372,7 +372,7 @@ def read(self, n: int = -1) -> bytes:
while remaining_bytes > 0 or read_everything:
if len(self._buffer) == 0:
try:
self._buffer = next(self._content)
self._buffer = next(self._content) # type: ignore[arg-type]
except StopIteration:
break
bytes_available = len(self._buffer)
Expand Down Expand Up @@ -416,7 +416,7 @@ def __next__(self) -> bytes:
return self.read(1)

def __iter__(self) -> Iterator[bytes]:
return self._content
return self._content # type: ignore[return-value]

def __exit__(
self,
Expand Down
4 changes: 2 additions & 2 deletions databricks/sdk/_widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _remove_all(self):
# We only use ipywidgets if we are in a notebook interactive shell otherwise we raise error,
# to fallback to using default_widgets. Also, users WILL have IPython in their notebooks (jupyter),
# because we DO NOT SUPPORT any other notebook backends, and hence fallback to default_widgets.
from IPython.core.getipython import get_ipython
from IPython.core.getipython import get_ipython # type: ignore[import-not-found]

# Detect if we are in an interactive notebook by iterating over the mro of the current ipython instance,
# to find ZMQInteractiveShell (jupyter). When used from REPL or file, this check will fail, since the
Expand Down Expand Up @@ -79,5 +79,5 @@ def _remove_all(self):
except:
from .default_widgets_utils import DefaultValueOnlyWidgetUtils

widget_impl = DefaultValueOnlyWidgetUtils
widget_impl = DefaultValueOnlyWidgetUtils # type: ignore[assignment, misc]
logging.debug("Using default_value_only implementation for dbutils.")
4 changes: 2 additions & 2 deletions databricks/sdk/_widgets/ipywidgets_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing

from IPython.core.display_functions import display
from ipywidgets.widgets import (ValueWidget, Widget, widget_box,
from IPython.core.display_functions import display # type: ignore[import-not-found]
from ipywidgets.widgets import (ValueWidget, Widget, widget_box, # type: ignore[import-not-found,import-untyped]
widget_selection, widget_string)

from .default_widgets_utils import WidgetUtils
Expand Down
2 changes: 1 addition & 1 deletion databricks/sdk/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .service.provisioning import Workspace


def add_workspace_id_header(cfg: "Config", headers: Dict[str, str]):
def add_workspace_id_header(cfg: "Config", headers: Dict[str, str]): # type: ignore[name-defined]
if cfg.azure_workspace_resource_id:
headers["X-Databricks-Azure-Workspace-Resource-Id"] = cfg.azure_workspace_resource_id

Expand Down
2 changes: 1 addition & 1 deletion databricks/sdk/casing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ class _Name(object):
def __init__(self, raw_name: str):
#
self._segments = []
segment = []
segment = [] # type: ignore[var-annotated]
for ch in raw_name:
if ch.isupper():
if segment:
Expand Down
104 changes: 52 additions & 52 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import urllib.parse
from typing import Dict, Iterable, List, Optional

import requests
import requests # type: ignore[import-untyped]

from . import useragent
from ._base_client import _fix_host_if_needed
Expand All @@ -28,10 +28,10 @@ class ConfigAttribute:
"""Configuration attribute metadata and descriptor protocols."""

# name and transform are discovered from Config.__new__
name: str = None
name: str = None # type: ignore[assignment]
transform: type = str

def __init__(self, env: str = None, auth: str = None, sensitive: bool = False):
def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): # type: ignore[assignment]
self.env = env
self.auth = auth
self.sensitive = sensitive
Expand All @@ -41,7 +41,7 @@ def __get__(self, cfg: "Config", owner):
return None
return cfg._inner.get(self.name, None)

def __set__(self, cfg: "Config", value: any):
def __set__(self, cfg: "Config", value: any): # type: ignore[valid-type]
cfg._inner[self.name] = self.transform(value)

def __repr__(self) -> str:
Expand All @@ -59,58 +59,58 @@ def with_user_agent_extra(key: str, value: str):


class Config:
host: str = ConfigAttribute(env="DATABRICKS_HOST")
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID")
host: str = ConfigAttribute(env="DATABRICKS_HOST") # type: ignore[assignment]
account_id: str = ConfigAttribute(env="DATABRICKS_ACCOUNT_ID") # type: ignore[assignment]

# PAT token.
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True)
token: str = ConfigAttribute(env="DATABRICKS_TOKEN", auth="pat", sensitive=True) # type: ignore[assignment]

# Audience for OIDC ID token source accepting an audience as a parameter.
# For example, the GitHub action ID token source.
token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc")
token_audience: str = ConfigAttribute(env="DATABRICKS_TOKEN_AUDIENCE", auth="github-oidc") # type: ignore[assignment]

# Environment variable for OIDC token.
oidc_token_env: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_ENV", auth="env-oidc")
oidc_token_filepath: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_FILE", auth="file-oidc")

username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic")
password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True)

client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth")
client_secret: str = ConfigAttribute(env="DATABRICKS_CLIENT_SECRET", auth="oauth", sensitive=True)
profile: str = ConfigAttribute(env="DATABRICKS_CONFIG_PROFILE")
config_file: str = ConfigAttribute(env="DATABRICKS_CONFIG_FILE")
google_service_account: str = ConfigAttribute(env="DATABRICKS_GOOGLE_SERVICE_ACCOUNT", auth="google")
google_credentials: str = ConfigAttribute(env="GOOGLE_CREDENTIALS", auth="google", sensitive=True)
azure_workspace_resource_id: str = ConfigAttribute(env="DATABRICKS_AZURE_RESOURCE_ID", auth="azure")
azure_use_msi: bool = ConfigAttribute(env="ARM_USE_MSI", auth="azure")
azure_client_secret: str = ConfigAttribute(env="ARM_CLIENT_SECRET", auth="azure", sensitive=True)
azure_client_id: str = ConfigAttribute(env="ARM_CLIENT_ID", auth="azure")
azure_tenant_id: str = ConfigAttribute(env="ARM_TENANT_ID", auth="azure")
azure_environment: str = ConfigAttribute(env="ARM_ENVIRONMENT")
databricks_cli_path: str = ConfigAttribute(env="DATABRICKS_CLI_PATH")
auth_type: str = ConfigAttribute(env="DATABRICKS_AUTH_TYPE")
cluster_id: str = ConfigAttribute(env="DATABRICKS_CLUSTER_ID")
warehouse_id: str = ConfigAttribute(env="DATABRICKS_WAREHOUSE_ID")
serverless_compute_id: str = ConfigAttribute(env="DATABRICKS_SERVERLESS_COMPUTE_ID")
skip_verify: bool = ConfigAttribute()
http_timeout_seconds: float = ConfigAttribute()
debug_truncate_bytes: int = ConfigAttribute(env="DATABRICKS_DEBUG_TRUNCATE_BYTES")
debug_headers: bool = ConfigAttribute(env="DATABRICKS_DEBUG_HEADERS")
rate_limit: int = ConfigAttribute(env="DATABRICKS_RATE_LIMIT")
retry_timeout_seconds: int = ConfigAttribute()
oidc_token_env: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_ENV", auth="env-oidc") # type: ignore[assignment]
oidc_token_filepath: str = ConfigAttribute(env="DATABRICKS_OIDC_TOKEN_FILE", auth="file-oidc") # type: ignore[assignment]

username: str = ConfigAttribute(env="DATABRICKS_USERNAME", auth="basic") # type: ignore[assignment]
password: str = ConfigAttribute(env="DATABRICKS_PASSWORD", auth="basic", sensitive=True) # type: ignore[assignment]

client_id: str = ConfigAttribute(env="DATABRICKS_CLIENT_ID", auth="oauth") # type: ignore[assignment]
client_secret: str = ConfigAttribute(env="DATABRICKS_CLIENT_SECRET", auth="oauth", sensitive=True) # type: ignore[assignment]
profile: str = ConfigAttribute(env="DATABRICKS_CONFIG_PROFILE") # type: ignore[assignment]
config_file: str = ConfigAttribute(env="DATABRICKS_CONFIG_FILE") # type: ignore[assignment]
google_service_account: str = ConfigAttribute(env="DATABRICKS_GOOGLE_SERVICE_ACCOUNT", auth="google") # type: ignore[assignment]
google_credentials: str = ConfigAttribute(env="GOOGLE_CREDENTIALS", auth="google", sensitive=True) # type: ignore[assignment]
azure_workspace_resource_id: str = ConfigAttribute(env="DATABRICKS_AZURE_RESOURCE_ID", auth="azure") # type: ignore[assignment]
azure_use_msi: bool = ConfigAttribute(env="ARM_USE_MSI", auth="azure") # type: ignore[assignment]
azure_client_secret: str = ConfigAttribute(env="ARM_CLIENT_SECRET", auth="azure", sensitive=True) # type: ignore[assignment]
azure_client_id: str = ConfigAttribute(env="ARM_CLIENT_ID", auth="azure") # type: ignore[assignment]
azure_tenant_id: str = ConfigAttribute(env="ARM_TENANT_ID", auth="azure") # type: ignore[assignment]
azure_environment: str = ConfigAttribute(env="ARM_ENVIRONMENT") # type: ignore[assignment]
databricks_cli_path: str = ConfigAttribute(env="DATABRICKS_CLI_PATH") # type: ignore[assignment]
auth_type: str = ConfigAttribute(env="DATABRICKS_AUTH_TYPE") # type: ignore[assignment]
cluster_id: str = ConfigAttribute(env="DATABRICKS_CLUSTER_ID") # type: ignore[assignment]
warehouse_id: str = ConfigAttribute(env="DATABRICKS_WAREHOUSE_ID") # type: ignore[assignment]
serverless_compute_id: str = ConfigAttribute(env="DATABRICKS_SERVERLESS_COMPUTE_ID") # type: ignore[assignment]
skip_verify: bool = ConfigAttribute() # type: ignore[assignment]
http_timeout_seconds: float = ConfigAttribute() # type: ignore[assignment]
debug_truncate_bytes: int = ConfigAttribute(env="DATABRICKS_DEBUG_TRUNCATE_BYTES") # type: ignore[assignment]
debug_headers: bool = ConfigAttribute(env="DATABRICKS_DEBUG_HEADERS") # type: ignore[assignment]
rate_limit: int = ConfigAttribute(env="DATABRICKS_RATE_LIMIT") # type: ignore[assignment]
retry_timeout_seconds: int = ConfigAttribute() # type: ignore[assignment]
metadata_service_url = ConfigAttribute(
env="DATABRICKS_METADATA_SERVICE_URL",
auth="metadata-service",
sensitive=True,
)
max_connection_pools: int = ConfigAttribute()
max_connections_per_pool: int = ConfigAttribute()
max_connection_pools: int = ConfigAttribute() # type: ignore[assignment]
max_connections_per_pool: int = ConfigAttribute() # type: ignore[assignment]
databricks_environment: Optional[DatabricksEnvironment] = None

disable_async_token_refresh: bool = ConfigAttribute(env="DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH")
disable_async_token_refresh: bool = ConfigAttribute(env="DATABRICKS_DISABLE_ASYNC_TOKEN_REFRESH") # type: ignore[assignment]

disable_experimental_files_api_client: bool = ConfigAttribute(
disable_experimental_files_api_client: bool = ConfigAttribute( # type: ignore[assignment]
env="DATABRICKS_DISABLE_EXPERIMENTAL_FILES_API_CLIENT"
)

Expand Down Expand Up @@ -217,8 +217,8 @@ def __init__(
**kwargs,
):
self._header_factory = None
self._inner = {}
self._user_agent_other_info = []
self._inner = {} # type: ignore[var-annotated]
self._user_agent_other_info = [] # type: ignore[var-annotated]
if credentials_strategy and credentials_provider:
raise ValueError("When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
if credentials_provider:
Expand Down Expand Up @@ -284,11 +284,11 @@ def parse_dsn(dsn: str) -> "Config":
if attr.name not in query:
continue
kwargs[attr.name] = query[attr.name]
return Config(**kwargs)
return Config(**kwargs) # type: ignore[arg-type]

def authenticate(self) -> Dict[str, str]:
"""Returns a list of fresh authentication headers"""
return self._header_factory()
return self._header_factory() # type: ignore[misc]

def as_dict(self) -> dict:
return self._inner
Expand All @@ -314,7 +314,7 @@ def environment(self) -> DatabricksEnvironment:
for environment in ALL_ENVS:
if environment.cloud != Cloud.AZURE:
continue
if environment.azure_environment.name != azure_env:
if environment.azure_environment.name != azure_env: # type: ignore[union-attr]
continue
if environment.dns_zone.startswith(".dev") or environment.dns_zone.startswith(".staging"):
continue
Expand Down Expand Up @@ -343,7 +343,7 @@ def is_account_client(self) -> bool:

@property
def arm_environment(self) -> AzureEnvironment:
return self.environment.azure_environment
return self.environment.azure_environment # type: ignore[return-value]

@property
def effective_azure_login_app_id(self):
Expand Down Expand Up @@ -414,11 +414,11 @@ def debug_string(self) -> str:
buf.append(f"Env: {', '.join(envs_used)}")
return ". ".join(buf)

def to_dict(self) -> Dict[str, any]:
def to_dict(self) -> Dict[str, any]: # type: ignore[valid-type]
return self._inner

@property
def sql_http_path(self) -> Optional[str]:
def sql_http_path(self) -> Optional[str]: # type: ignore[return]
"""(Experimental) Return HTTP path for SQL Drivers.

If `cluster_id` or `warehouse_id` are configured, return a valid HTTP Path argument
Expand Down Expand Up @@ -465,8 +465,8 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
v.name = name
v.transform = anno.get(name, str)
attrs.append(v)
cls._attributes = attrs
return cls._attributes
cls._attributes = attrs # type: ignore[attr-defined]
return cls._attributes # type: ignore[attr-defined]

def _fix_host_if_needed(self):
updated_host = _fix_host_if_needed(self.host)
Expand Down Expand Up @@ -499,7 +499,7 @@ def load_azure_tenant_id(self):
self.azure_tenant_id = path_segments[1]
logger.debug(f"Loaded tenant ID: {self.azure_tenant_id}")

def _set_inner_config(self, keyword_args: Dict[str, any]):
def _set_inner_config(self, keyword_args: Dict[str, any]): # type: ignore[valid-type]
for attr in self.attributes():
if attr.name not in keyword_args:
continue
Expand Down
4 changes: 2 additions & 2 deletions databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._base_client import _BaseClient
from .config import *
# To preserve backwards compatibility (as these definitions were previously in this module)
from .credentials_provider import *
from .credentials_provider import * # type: ignore[no-redef]
from .errors import DatabricksError, _ErrorCustomizer
from .oauth import retrieve_token

Expand Down Expand Up @@ -80,7 +80,7 @@ def do(
if url is None:
# Remove extra `/` from path for Files API
# Once we've fixed the OpenAPI spec, we can remove this
path = re.sub("^/api/2.0/fs/files//", "/api/2.0/fs/files/", path)
path = re.sub("^/api/2.0/fs/files//", "/api/2.0/fs/files/", path) # type: ignore[arg-type]
url = f"{self._cfg.host}{path}"
return self._api_client.do(
method=method,
Expand Down
Loading
Loading