diff --git a/cryosparc/api.py b/cryosparc/api.py index e1f9c1c8..bcbe37c0 100644 --- a/cryosparc/api.py +++ b/cryosparc/api.py @@ -7,6 +7,7 @@ from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, TypedDict, Union import httpx +from pydantic import BaseModel from . import registry from .errors import APIError @@ -392,9 +393,9 @@ def _decode_json_response(value: Any, schema: dict): return model_class(value) elif model_class and issubclass(model_class, dict): # typed dict return model_class(**value) - elif model_class: # pydantic model + elif model_class and issubclass(model_class, BaseModel): # pydantic model # use model_validate in case validator result derives from subtype, e.g., Event model - return model_class.model_validate(value) # type: ignore + return model_class.model_validate(value) warnings.warn( f"[API] Warning: Received API response with unregistered schema type {schema['$ref']}. " "Returning as plain object." diff --git a/cryosparc/dataset/dtype.py b/cryosparc/dataset/dtype.py index 3d8811ea..102a697a 100644 --- a/cryosparc/dataset/dtype.py +++ b/cryosparc/dataset/dtype.py @@ -35,7 +35,7 @@ class DatasetHeader(TypedDict): """Field names that require decompression.""" -DSET_TO_TYPE_MAP: Dict[DsetType, Type] = { +DSET_TO_TYPE_MAP: Dict[DsetType, Type[Union[n.number, n.object_]]] = { DsetType.T_F32: n.float32, DsetType.T_F64: n.float64, DsetType.T_C32: n.complex64, @@ -107,8 +107,10 @@ def get_data_field(data: Data, field: str) -> Field: def get_data_field_dtype(data: Data, field: str) -> "DTypeLike": t = data.type(field) - if t == 0 or t not in DSET_TO_TYPE_MAP: - raise KeyError(f"Unknown dataset field {field} or field type {t}") + if t == 0: + raise KeyError(f"Unknown dataset field {field}") + elif t not in DSET_TO_TYPE_MAP: + raise KeyError(f"Unknown dataset field type {t}") dt = n.dtype(DSET_TO_TYPE_MAP[t]) shape = data.getshp(field) return (dt.str, shape) if shape else dt.str diff --git a/tests/conftest.py b/tests/conftest.py index 41e7a3e9..b8b46048 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -255,7 +255,13 @@ def mock_api_client_class(mock_user, monkeypatch): @pytest.fixture def cs(mock_api_client_class): - return CryoSPARC("https://cryosparc.example.com", email="structura@example.com", password="password") + return CryoSPARC( + "https://cryosparc.example.com", + email="structura@example.com", + password="password", + host=None, + base_port=None, + ) @pytest.fixture diff --git a/tests/test_cli.py b/tests/test_cli.py index 21bde976..6468b3a7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -31,7 +31,7 @@ def test_cli_login(mock_api_client_class, mock_auth_path): def test_cli_login_auth(mock_user, mock_api_client_class, mock_auth_path): - cs = CryoSPARC("https://cryosparc.example.com", email="structura@example.com") + cs = CryoSPARC("https://cryosparc.example.com", email="structura@example.com", host=None, base_port=None) mock_api_client_class.__call__.assert_called_with(auth="abc123") # called with token assert cs.user == mock_user assert cs.test_connection()