Skip to content

Commit 3050ccb

Browse files
Merge pull request #39 from geo-engine/ml-model-input-outpt-shape-2
Add ml model input and output shape to allow models run on entire tiles
2 parents 66dfd0e + f3c5cb9 commit 3050ccb

26 files changed

+577
-31
lines changed

.generation/config.ini

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[input]
2-
backendCommit = 1076a616369dcc33e86b422a9364ac99553a18f8
2+
backendCommit = cdb162df11bff5a3ae5854126d15538c77a2cabb
33

44
[general]
55
githubUrl = https://github.com/geo-engine/openapi-client
6-
version = 0.0.23
6+
version = 0.0.24
77

88
[python]
99
name = geoengine_openapi_client

.generation/input/openapi.json

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6953,20 +6953,22 @@
69536953
"required": [
69546954
"fileName",
69556955
"inputType",
6956-
"numInputBands",
6957-
"outputType"
6956+
"outputType",
6957+
"inputShape",
6958+
"outputShape"
69586959
],
69596960
"properties": {
69606961
"fileName": {
69616962
"type": "string"
69626963
},
6964+
"inputShape": {
6965+
"$ref": "#/components/schemas/MlTensorShape3D"
6966+
},
69636967
"inputType": {
69646968
"$ref": "#/components/schemas/RasterDataType"
69656969
},
6966-
"numInputBands": {
6967-
"type": "integer",
6968-
"format": "int32",
6969-
"minimum": 0
6970+
"outputShape": {
6971+
"$ref": "#/components/schemas/MlTensorShape3D"
69706972
},
69716973
"outputType": {
69726974
"$ref": "#/components/schemas/RasterDataType"
@@ -7005,6 +7007,32 @@
70057007
}
70067008
}
70077009
},
7010+
"MlTensorShape3D": {
7011+
"type": "object",
7012+
"description": "A struct describing tensor shape for `MlModelMetadata`",
7013+
"required": [
7014+
"y",
7015+
"x",
7016+
"bands"
7017+
],
7018+
"properties": {
7019+
"bands": {
7020+
"type": "integer",
7021+
"format": "int32",
7022+
"minimum": 0
7023+
},
7024+
"x": {
7025+
"type": "integer",
7026+
"format": "int32",
7027+
"minimum": 0
7028+
},
7029+
"y": {
7030+
"type": "integer",
7031+
"format": "int32",
7032+
"minimum": 0
7033+
}
7034+
}
7035+
},
70087036
"MockDatasetDataSourceLoadingInfo": {
70097037
"type": "object",
70107038
"required": [

python/.openapi-generator/FILES

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ docs/MlModel.md
8787
docs/MlModelMetadata.md
8888
docs/MlModelNameResponse.md
8989
docs/MlModelResource.md
90+
docs/MlTensorShape3D.md
9091
docs/MockDatasetDataSourceLoadingInfo.md
9192
docs/MockMetaData.md
9293
docs/MultiBandRasterColorizer.md
@@ -338,6 +339,7 @@ geoengine_openapi_client/models/ml_model.py
338339
geoengine_openapi_client/models/ml_model_metadata.py
339340
geoengine_openapi_client/models/ml_model_name_response.py
340341
geoengine_openapi_client/models/ml_model_resource.py
342+
geoengine_openapi_client/models/ml_tensor_shape3_d.py
341343
geoengine_openapi_client/models/mock_dataset_data_source_loading_info.py
342344
geoengine_openapi_client/models/mock_meta_data.py
343345
geoengine_openapi_client/models/multi_band_raster_colorizer.py
@@ -567,6 +569,7 @@ test/test_ml_model.py
567569
test/test_ml_model_metadata.py
568570
test/test_ml_model_name_response.py
569571
test/test_ml_model_resource.py
572+
test/test_ml_tensor_shape3_d.py
570573
test/test_mock_dataset_data_source_loading_info.py
571574
test/test_mock_meta_data.py
572575
test/test_multi_band_raster_colorizer.py

python/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ Class | Method | HTTP request | Description
267267
- [MlModelMetadata](docs/MlModelMetadata.md)
268268
- [MlModelNameResponse](docs/MlModelNameResponse.md)
269269
- [MlModelResource](docs/MlModelResource.md)
270+
- [MlTensorShape3D](docs/MlTensorShape3D.md)
270271
- [MockDatasetDataSourceLoadingInfo](docs/MockDatasetDataSourceLoadingInfo.md)
271272
- [MockMetaData](docs/MockMetaData.md)
272273
- [MultiBandRasterColorizer](docs/MultiBandRasterColorizer.md)

python/geoengine_openapi_client/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
from geoengine_openapi_client.models.ml_model_metadata import MlModelMetadata
128128
from geoengine_openapi_client.models.ml_model_name_response import MlModelNameResponse
129129
from geoengine_openapi_client.models.ml_model_resource import MlModelResource
130+
from geoengine_openapi_client.models.ml_tensor_shape3_d import MlTensorShape3D
130131
from geoengine_openapi_client.models.mock_dataset_data_source_loading_info import MockDatasetDataSourceLoadingInfo
131132
from geoengine_openapi_client.models.mock_meta_data import MockMetaData
132133
from geoengine_openapi_client.models.multi_band_raster_colorizer import MultiBandRasterColorizer

python/geoengine_openapi_client/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
from geoengine_openapi_client.models.ml_model_metadata import MlModelMetadata
9696
from geoengine_openapi_client.models.ml_model_name_response import MlModelNameResponse
9797
from geoengine_openapi_client.models.ml_model_resource import MlModelResource
98+
from geoengine_openapi_client.models.ml_tensor_shape3_d import MlTensorShape3D
9899
from geoengine_openapi_client.models.mock_dataset_data_source_loading_info import MockDatasetDataSourceLoadingInfo
99100
from geoengine_openapi_client.models.mock_meta_data import MockMetaData
100101
from geoengine_openapi_client.models.multi_band_raster_colorizer import MultiBandRasterColorizer

python/geoengine_openapi_client/models/ml_model_metadata.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pydantic import BaseModel, ConfigDict, Field, StrictStr
2222
from typing import Any, ClassVar, Dict, List
23-
from typing_extensions import Annotated
23+
from geoengine_openapi_client.models.ml_tensor_shape3_d import MlTensorShape3D
2424
from geoengine_openapi_client.models.raster_data_type import RasterDataType
2525
from typing import Optional, Set
2626
from typing_extensions import Self
@@ -30,10 +30,11 @@ class MlModelMetadata(BaseModel):
3030
MlModelMetadata
3131
""" # noqa: E501
3232
file_name: StrictStr = Field(alias="fileName")
33+
input_shape: MlTensorShape3D = Field(alias="inputShape")
3334
input_type: RasterDataType = Field(alias="inputType")
34-
num_input_bands: Annotated[int, Field(strict=True, ge=0)] = Field(alias="numInputBands")
35+
output_shape: MlTensorShape3D = Field(alias="outputShape")
3536
output_type: RasterDataType = Field(alias="outputType")
36-
__properties: ClassVar[List[str]] = ["fileName", "inputType", "numInputBands", "outputType"]
37+
__properties: ClassVar[List[str]] = ["fileName", "inputShape", "inputType", "outputShape", "outputType"]
3738

3839
model_config = ConfigDict(
3940
populate_by_name=True,
@@ -74,6 +75,12 @@ def to_dict(self) -> Dict[str, Any]:
7475
exclude=excluded_fields,
7576
exclude_none=True,
7677
)
78+
# override the default output from pydantic by calling `to_dict()` of input_shape
79+
if self.input_shape:
80+
_dict['inputShape'] = self.input_shape.to_dict()
81+
# override the default output from pydantic by calling `to_dict()` of output_shape
82+
if self.output_shape:
83+
_dict['outputShape'] = self.output_shape.to_dict()
7784
return _dict
7885

7986
@classmethod
@@ -87,8 +94,9 @@ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
8794

8895
_obj = cls.model_validate({
8996
"fileName": obj.get("fileName"),
97+
"inputShape": MlTensorShape3D.from_dict(obj["inputShape"]) if obj.get("inputShape") is not None else None,
9098
"inputType": obj.get("inputType"),
91-
"numInputBands": obj.get("numInputBands"),
99+
"outputShape": MlTensorShape3D.from_dict(obj["outputShape"]) if obj.get("outputShape") is not None else None,
92100
"outputType": obj.get("outputType")
93101
})
94102
return _obj
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# coding: utf-8
2+
3+
"""
4+
Geo Engine API
5+
6+
No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
7+
8+
The version of the OpenAPI document: 0.8.0
9+
Contact: dev@geoengine.de
10+
Generated by OpenAPI Generator (https://openapi-generator.tech)
11+
12+
Do not edit the class manually.
13+
""" # noqa: E501
14+
15+
16+
from __future__ import annotations
17+
import pprint
18+
import re # noqa: F401
19+
import json
20+
21+
from pydantic import BaseModel, ConfigDict, Field
22+
from typing import Any, ClassVar, Dict, List
23+
from typing_extensions import Annotated
24+
from typing import Optional, Set
25+
from typing_extensions import Self
26+
27+
class MlTensorShape3D(BaseModel):
28+
"""
29+
A struct describing tensor shape for `MlModelMetadata`
30+
""" # noqa: E501
31+
bands: Annotated[int, Field(strict=True, ge=0)]
32+
x: Annotated[int, Field(strict=True, ge=0)]
33+
y: Annotated[int, Field(strict=True, ge=0)]
34+
__properties: ClassVar[List[str]] = ["bands", "x", "y"]
35+
36+
model_config = ConfigDict(
37+
populate_by_name=True,
38+
validate_assignment=True,
39+
protected_namespaces=(),
40+
)
41+
42+
43+
def to_str(self) -> str:
44+
"""Returns the string representation of the model using alias"""
45+
return pprint.pformat(self.model_dump(by_alias=True))
46+
47+
def to_json(self) -> str:
48+
"""Returns the JSON representation of the model using alias"""
49+
# TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
50+
return json.dumps(self.to_dict())
51+
52+
@classmethod
53+
def from_json(cls, json_str: str) -> Optional[Self]:
54+
"""Create an instance of MlTensorShape3D from a JSON string"""
55+
return cls.from_dict(json.loads(json_str))
56+
57+
def to_dict(self) -> Dict[str, Any]:
58+
"""Return the dictionary representation of the model using alias.
59+
60+
This has the following differences from calling pydantic's
61+
`self.model_dump(by_alias=True)`:
62+
63+
* `None` is only added to the output dict for nullable fields that
64+
were set at model initialization. Other fields with value `None`
65+
are ignored.
66+
"""
67+
excluded_fields: Set[str] = set([
68+
])
69+
70+
_dict = self.model_dump(
71+
by_alias=True,
72+
exclude=excluded_fields,
73+
exclude_none=True,
74+
)
75+
return _dict
76+
77+
@classmethod
78+
def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
79+
"""Create an instance of MlTensorShape3D from a dict"""
80+
if obj is None:
81+
return None
82+
83+
if not isinstance(obj, dict):
84+
return cls.model_validate(obj)
85+
86+
_obj = cls.model_validate({
87+
"bands": obj.get("bands"),
88+
"x": obj.get("x"),
89+
"y": obj.get("y")
90+
})
91+
return _obj
92+
93+

python/test/test_ml_model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,15 @@ def make_instance(self, include_optional) -> MlModel:
4040
display_name = '',
4141
metadata = geoengine_openapi_client.models.ml_model_metadata.MlModelMetadata(
4242
file_name = '',
43+
input_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
44+
bands = 0,
45+
x = 0,
46+
y = 0, ),
4347
input_type = 'U8',
44-
num_input_bands = 0,
48+
output_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
49+
bands = 0,
50+
x = 0,
51+
y = 0, ),
4552
output_type = 'U8', ),
4653
name = '',
4754
upload = ''
@@ -52,8 +59,15 @@ def make_instance(self, include_optional) -> MlModel:
5259
display_name = '',
5360
metadata = geoengine_openapi_client.models.ml_model_metadata.MlModelMetadata(
5461
file_name = '',
62+
input_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
63+
bands = 0,
64+
x = 0,
65+
y = 0, ),
5566
input_type = 'U8',
56-
num_input_bands = 0,
67+
output_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
68+
bands = 0,
69+
x = 0,
70+
y = 0, ),
5771
output_type = 'U8', ),
5872
name = '',
5973
upload = '',

python/test/test_ml_model_metadata.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,29 @@ def make_instance(self, include_optional) -> MlModelMetadata:
3737
if include_optional:
3838
return MlModelMetadata(
3939
file_name = '',
40+
input_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
41+
bands = 0,
42+
x = 0,
43+
y = 0, ),
4044
input_type = 'U8',
41-
num_input_bands = 0,
45+
output_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
46+
bands = 0,
47+
x = 0,
48+
y = 0, ),
4249
output_type = 'U8'
4350
)
4451
else:
4552
return MlModelMetadata(
4653
file_name = '',
54+
input_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
55+
bands = 0,
56+
x = 0,
57+
y = 0, ),
4758
input_type = 'U8',
48-
num_input_bands = 0,
59+
output_shape = geoengine_openapi_client.models.ml_tensor_shape3_d.MlTensorShape3D(
60+
bands = 0,
61+
x = 0,
62+
y = 0, ),
4963
output_type = 'U8',
5064
)
5165
"""

0 commit comments

Comments
 (0)