Skip to content

Commit 31ad668

Browse files
fix ml error msg
1 parent fffeacd commit 31ad668

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

geoengine/ml.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,12 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
103103
if not data_type.tensor_type:
104104
raise InputException('Only tensor input types are supported')
105105
elem_type = data_type.tensor_type.elem_type
106-
if elem_type != RASTER_TYPE_TO_ONNX_TYPE[expected_type]:
106+
expected_tensor_type = RASTER_TYPE_TO_ONNX_TYPE[expected_type]
107+
if elem_type != expected_tensor_type:
107108
elem_type_str = tensor_dtype_to_string(elem_type)
109+
expected_type_str = tensor_dtype_to_string(expected_tensor_type)
108110
raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
109-
f'expected type `{expected_type}`')
111+
f'expected type `{expected_type_str}`')
110112

111113
model_inputs = onnx_model.graph.input
112114
model_outputs = onnx_model.graph.output

tests/test_ml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_uploading_onnx_model(self):
106106
)
107107
self.assertEqual(
108108
str(exception.exception),
109-
'Model input type `TensorProto.FLOAT` does not match the expected type `F64`'
109+
'Model input type `TensorProto.FLOAT` does not match the expected type `TensorProto.DOUBLE`'
110110
)
111111

112112
with self.assertRaises(ge.InputException) as exception:
@@ -126,5 +126,5 @@ def test_uploading_onnx_model(self):
126126
)
127127
self.assertEqual(
128128
str(exception.exception),
129-
'Model output type `TensorProto.INT64` does not match the expected type `I32`'
129+
'Model output type `TensorProto.INT64` does not match the expected type `TensorProto.INT32`'
130130
)

0 commit comments

Comments
 (0)