|
10 | 10 | Sequence,
|
11 | 11 | Tuple,
|
12 | 12 | Union,
|
| 13 | + Hashable, |
| 14 | + cast, |
13 | 15 | )
|
14 | 16 |
|
15 | 17 | import ipyleaflet # type: ignore
|
@@ -49,7 +51,7 @@ def __init__(
|
49 | 51 | # data resources, and so column names and dtype need to be
|
50 | 52 | # passed in as parameters.
|
51 | 53 | self._aim_metadata_columns: Optional[List[str]] = None
|
52 |
| - self._aim_metadata_dtype: Dict[str, Any] = dict() |
| 54 | + self._aim_metadata_dtype: Dict[str, Union[str, type, np.dtype]] = dict() |
53 | 55 | if isinstance(aim_metadata_dtype, Mapping):
|
54 | 56 | self._aim_metadata_columns = list(aim_metadata_dtype.keys())
|
55 | 57 | self._aim_metadata_dtype.update(aim_metadata_dtype)
|
@@ -150,7 +152,19 @@ def _parse_general_metadata(
|
150 | 152 | "longitude": "float64",
|
151 | 153 | "sex_call": "object",
|
152 | 154 | }
|
153 |
| - df = pd.read_csv(io.BytesIO(data), dtype=dtype, na_values="") |
| 155 | + # Mapping of string dtypes to actual dtypes |
| 156 | + dtype_map = { |
| 157 | + "object": str, |
| 158 | + "int64": np.int64, |
| 159 | + "float64": np.float64, |
| 160 | + } |
| 161 | + |
| 162 | + # Convert string dtypes to actual dtypes |
| 163 | + dtype_fixed: Mapping[Hashable, Union[str, np.dtype, type]] = { |
| 164 | + col: dtype_map.get(dtype[col], str) for col in dtype |
| 165 | + } |
| 166 | + |
| 167 | + df = pd.read_csv(io.BytesIO(data), dtype=dtype_fixed, na_values="") |
154 | 168 |
|
155 | 169 | # Ensure all column names are lower case.
|
156 | 170 | df.columns = [c.lower() for c in df.columns] # type: ignore
|
@@ -470,7 +484,12 @@ def _parse_aim_metadata(
|
470 | 484 | if isinstance(data, bytes):
|
471 | 485 | # Parse CSV data.
|
472 | 486 | df = pd.read_csv(
|
473 |
| - io.BytesIO(data), dtype=self._aim_metadata_dtype, na_values="" |
| 487 | + io.BytesIO(data), |
| 488 | + dtype=cast( |
| 489 | + Mapping[Hashable, Union[str, type, np.dtype]], |
| 490 | + self._aim_metadata_dtype, |
| 491 | + ), |
| 492 | + na_values="", |
474 | 493 | )
|
475 | 494 |
|
476 | 495 | # Ensure all column names are lower case.
|
@@ -901,9 +920,7 @@ def _prep_sample_selection_cache_params(
|
901 | 920 | # integer indices instead.
|
902 | 921 | df_samples = self.sample_metadata(sample_sets=sample_sets)
|
903 | 922 | sample_query_options = sample_query_options or {}
|
904 |
| - loc_samples = ( |
905 |
| - df_samples.eval(sample_query, **sample_query_options).values, |
906 |
| - ) |
| 923 | + loc_samples = df_samples.eval(sample_query, **sample_query_options).values |
907 | 924 | sample_indices = np.nonzero(loc_samples)[0].tolist()
|
908 | 925 |
|
909 | 926 | return sample_sets, sample_indices
|
|
0 commit comments