Skip to content

Commit afe0e2f

Browse files
fix formatting and typing
1 parent 5ec4475 commit afe0e2f

File tree

6 files changed

+58
-28
lines changed

6 files changed

+58
-28
lines changed

geoengine/colorizer.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,13 @@ def from_response(response: geoengine_openapi_client.Colorizer) -> Colorizer:
230230
raise TypeError("Unknown colorizer type")
231231

232232

233+
def rgba_from_list(values: list[int]) -> Rgba:
234+
"""Convert a list of integers to an RGBA tuple."""
235+
if len(values) != 4:
236+
raise ValueError(f"Expected a list of 4 integers, got {len(values)} instead.")
237+
return (values[0], values[1], values[2], values[3])
238+
239+
233240
@dataclass
234241
class LinearGradientColorizer(Colorizer):
235242
'''A linear gradient colorizer.'''
@@ -242,10 +249,10 @@ def from_response_linear(response: geoengine_openapi_client.LinearGradient) -> L
242249
"""Create a colorizer from a response."""
243250
breakpoints = [ColorBreakpoint.from_response(breakpoint) for breakpoint in response.breakpoints]
244251
return LinearGradientColorizer(
245-
no_data_color=response.no_data_color,
252+
no_data_color=rgba_from_list(response.no_data_color),
246253
breakpoints=breakpoints,
247-
over_color=response.over_color,
248-
under_color=response.under_color,
254+
over_color=rgba_from_list(response.over_color),
255+
under_color=rgba_from_list(response.under_color),
249256
)
250257

251258
def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
@@ -273,9 +280,9 @@ def from_response_logarithmic(
273280
breakpoints = [ColorBreakpoint.from_response(breakpoint) for breakpoint in response.breakpoints]
274281
return LogarithmicGradientColorizer(
275282
breakpoints=breakpoints,
276-
no_data_color=response.no_data_color,
277-
over_color=response.over_color,
278-
under_color=response.under_color,
283+
no_data_color=rgba_from_list(response.no_data_color),
284+
over_color=rgba_from_list(response.over_color),
285+
under_color=rgba_from_list(response.under_color),
279286
)
280287

281288
def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
@@ -300,16 +307,16 @@ def from_response_palette(response: geoengine_openapi_client.PaletteColorizer) -
300307
"""Create a colorizer from a response."""
301308

302309
return PaletteColorizer(
303-
colors={float(k): v for k, v in response.colors.items()},
304-
no_data_color=response.no_data_color,
305-
default_color=response.default_color,
310+
colors={float(k): rgba_from_list(v) for k, v in response.colors.items()},
311+
no_data_color=rgba_from_list(response.no_data_color),
312+
default_color=rgba_from_list(response.default_color),
306313
)
307314

308315
def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
309316
"""Return the colorizer as a dictionary."""
310317
return geoengine_openapi_client.Colorizer(geoengine_openapi_client.PaletteColorizer(
311318
type='palette',
312-
colors={str(k): v for k,v in self.colors.items()},
319+
colors={str(k): v for k, v in self.colors.items()},
313320
default_color=self.default_color,
314321
no_data_color=self.no_data_color,
315322
))

geoengine/datasets.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,14 @@ def upload_dataframe(
466466
ints = [key for (key, value) in columns.items() if value.data_type == 'int']
467467
texts = [key for (key, value) in columns.items() if value.data_type == 'text']
468468

469+
result_descriptor = VectorResultDescriptor(
470+
data_type=vector_type,
471+
spatial_reference=df.crs.to_string(),
472+
columns=columns,
473+
).to_api_dict().actual_instance
474+
if not isinstance(result_descriptor, geoengine_openapi_client.TypedVectorResultDescriptor):
475+
raise TypeError('Expected TypedVectorResultDescriptor')
476+
469477
create = geoengine_openapi_client.CreateDataset(
470478
data_path=geoengine_openapi_client.DataPath(geoengine_openapi_client.DataPathOneOf1(
471479
upload=str(upload_id)
@@ -494,11 +502,9 @@ def upload_dataframe(
494502
),
495503
on_error=on_error.to_api_enum(),
496504
),
497-
result_descriptor=geoengine_openapi_client.VectorResultDescriptor.from_dict(VectorResultDescriptor(
498-
data_type=vector_type,
499-
spatial_reference=df.crs.to_string(),
500-
columns=columns,
501-
).to_api_dict().actual_instance.to_dict())
505+
result_descriptor=geoengine_openapi_client.VectorResultDescriptor.from_dict(
506+
result_descriptor.to_dict()
507+
)
502508
)
503509
)
504510
)

geoengine/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, response: Union[geoengine_openapi_client.ApiException, Dict[s
2121
super().__init__()
2222

2323
if isinstance(response, geoengine_openapi_client.ApiException):
24-
obj = json.loads(response.body)
24+
obj = json.loads(response.body) if response.body else {'error': 'unknown', 'message': 'unknown'}
2525
else:
2626
obj = response
2727

geoengine/types.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,29 @@ def __repr__(self) -> str:
660660
return f'{self.name}: {self.measurement}'
661661

662662

663+
def literal_raster_data_type(
664+
data_type: geoengine_openapi_client.RasterDataType
665+
) -> Literal['U8', 'U16', 'U32', 'U64', 'I8', 'I16', 'I32', 'I64', 'F32', 'F64']:
666+
'''Convert a `RasterDataType` to a literal'''
667+
668+
data_type_map: dict[
669+
geoengine_openapi_client.RasterDataType,
670+
Literal['U8', 'U16', 'U32', 'U64', 'I8', 'I16', 'I32', 'I64', 'F32', 'F64']
671+
] = {
672+
geoengine_openapi_client.RasterDataType.U8: 'U8',
673+
geoengine_openapi_client.RasterDataType.U16: 'U16',
674+
geoengine_openapi_client.RasterDataType.U32: 'U32',
675+
geoengine_openapi_client.RasterDataType.U64: 'U64',
676+
geoengine_openapi_client.RasterDataType.I8: 'I8',
677+
geoengine_openapi_client.RasterDataType.I16: 'I16',
678+
geoengine_openapi_client.RasterDataType.I32: 'I32',
679+
geoengine_openapi_client.RasterDataType.I64: 'I64',
680+
geoengine_openapi_client.RasterDataType.F32: 'F32',
681+
geoengine_openapi_client.RasterDataType.F64: 'F64',
682+
}
683+
return data_type_map[data_type]
684+
685+
663686
class RasterResultDescriptor(ResultDescriptor):
664687
'''
665688
A raster result descriptor
@@ -701,7 +724,7 @@ def from_response_raster(
701724
response: geoengine_openapi_client.TypedRasterResultDescriptor) -> RasterResultDescriptor:
702725
'''Parse a raster result descriptor from an http response'''
703726
spatial_ref = response.spatial_reference
704-
data_type = response.data_type.value
727+
data_type = literal_raster_data_type(response.data_type)
705728
bands = [RasterBandDescriptor.from_response(band) for band in response.bands]
706729

707730
time_bounds = None

geoengine/workflow.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -535,13 +535,6 @@ def save_as_dataset(
535535

536536
session = get_session()
537537

538-
print(geoengine_openapi_client.RasterDatasetFromWorkflow(
539-
name=name,
540-
display_name=display_name,
541-
description=description,
542-
query=query_rectangle
543-
).to_json())
544-
545538
with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
546539
workflows_api = geoengine_openapi_client.WorkflowsApi(api_client)
547540
response = workflows_api.dataset_from_workflow_handler(

geoengine/workflow_builder/operators.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,11 +694,12 @@ def from_operator_dict(cls, operator_dict: Dict[str, Any]) -> 'Expression':
694694

695695
output_band = None
696696
if "outputBand" in operator_dict["params"] and operator_dict["params"]["outputBand"] is not None:
697-
output_band = RasterBandDescriptor.from_response(
698-
geoengine_openapi_client.RasterBandDescriptor.from_dict(
699-
operator_dict["params"]["outputBand"]
700-
)
697+
raster_band_descriptor = geoengine_openapi_client.RasterBandDescriptor.from_dict(
698+
operator_dict["params"]["outputBand"]
701699
)
700+
if raster_band_descriptor is None:
701+
raise ValueError("Invalid output band")
702+
output_band = RasterBandDescriptor.from_response(raster_band_descriptor)
702703

703704
return Expression(
704705
expression=operator_dict["params"]["expression"],

0 commit comments

Comments
 (0)