Skip to content

Commit afc4879

Browse files
authored
Merge pull request #4635 from thewtex/transform-dict-repr
BUG: Make dict_from_transform more consistent with other dict representations
2 parents 0ea4ae9 + 5397f81 commit afc4879

File tree

4 files changed

+80
-44
lines changed

4 files changed

+80
-44
lines changed

Modules/Core/Transform/wrapping/test/itkTransformSerializationTest.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,13 @@
3535

3636
keys_to_test1 = [
3737
"name",
38-
"parametersValueType",
39-
"transformType",
40-
"inputDimension",
41-
"outputDimension",
4238
"inputSpaceName",
4339
"outputSpaceName",
4440
"numberOfParameters",
4541
"numberOfFixedParameters",
4642
]
4743
keys_to_test2 = ["parameters", "fixedParameters"]
44+
keys_to_test3 = ["transformParameterization", "parametersValueType", "inputDimension", "outputDimension"]
4845

4946
transform_object_list = []
5047
for i, transform_type in enumerate(transforms_to_test):
@@ -60,6 +57,8 @@
6057
# Test all the parameters
6158
for k in keys_to_test2:
6259
assert np.array_equal(serialize_deserialize[k], transform[k])
60+
for k in keys_to_test3:
61+
assert serialize_deserialize["transformType"][k], transform["transformType"][k]
6362
transform_object_list.append(transform)
6463

6564
print("Individual Transforms Test Done")
@@ -93,6 +92,9 @@
9392
for k in keys_to_test2:
9493
assert np.array_equal(transform_obj[k], transform_object_list[i][k])
9594

95+
for k in keys_to_test3:
96+
assert transform_object_list[i]["transformType"][k], transform["transformType"][k]
97+
9698

9799
# Test for transformation using de-serialized BSpline Transform
98100
ImageDimension = 2

Wrapping/Generators/Python/PyBase/pyBase.i

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,15 @@ str = str
430430
Return keys related to the transform's metadata.
431431
These keys are used in the dictionary resulting from dict(transform).
432432
"""
433-
result = ['name', 'inputDimension', 'outputDimension', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
433+
result = ['transformType', 'name', 'inputSpaceName', 'outputSpaceName', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters']
434434
return result
435435
436436
def __getitem__(self, key):
437437
"""Access metadata keys, see help(transform.keys), for string keys."""
438438
import itk
439439
if isinstance(key, str):
440440
state = itk.dict_from_transform(self)
441-
return state[0][key]
441+
return state[key]
442442
443443
def __setitem__(self, key, value):
444444
if isinstance(key, str):
@@ -474,7 +474,6 @@ str = str
474474
def __setstate__(self, state):
475475
"""Set object state, necessary for serialization with pickle."""
476476
import itk
477-
import numpy as np
478477
deserialized = itk.transform_from_dict(state)
479478
self.__dict__['this'] = deserialized
480479
%}

Wrapping/Generators/Python/Tests/extras.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,11 @@ def custom_callback(name, progress):
343343
parameters = np.asarray(transforms[0].GetParameters())
344344
assert np.allclose(parameters, np.array(baseline_additional_transform_params))
345345

346+
transform_dict = itk.dict_from_transform(transforms[0])
347+
transform_back = itk.transform_from_dict(transform_dict)
348+
transform_dict = itk.dict_from_transform(transforms)
349+
transform_back = itk.transform_from_dict(transform_dict)
350+
346351
# pipeline, auto_pipeline and templated class are tested in other files
347352

348353
# BridgeNumPy

Wrapping/Generators/Python/itk/support/extras.py

Lines changed: 67 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -981,57 +981,81 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
981981
)
982982

983983

984-
def dict_from_transform(transform: "itkt.TransformBase") -> Dict:
984+
def dict_from_transform(transform: Union["itkt.TransformBase", List["itkt.TransformBase"]]) -> Union[List[Dict], Dict]:
985+
"""Serialize a Python itk.Transform object to a pickable Python dictionary.
986+
987+
If the transform is a list of transforms, then a list of dictionaries is returned.
988+
If the transform is a single, non-Composite transform, then a single dictionary is returned.
989+
Composite transforms and nested composite transforms are flattened into a list of dictionaries.
990+
"""
985991
import itk
992+
datatype_dict = {"double": itk.D, "float": itk.F}
986993

987994
def update_transform_dict(current_transform):
988995
current_transform_type = current_transform.GetTransformTypeAsString()
989996
current_transform_type_split = current_transform_type.split("_")
990-
component = itk.template(current_transform)
991997

992-
in_transform_dict = dict()
993-
in_transform_dict["name"] = current_transform.GetObjectName()
998+
transform_type = dict()
999+
transform_parameterization = current_transform_type_split[0].replace("Transform", "")
1000+
transform_type["transformParameterization"] = transform_parameterization
9941001

995-
datatype_dict = {"double": itk.D, "float": itk.F}
996-
in_transform_dict["parametersValueType"] = python_to_js(
1002+
transform_type["parametersValueType"] = python_to_js(
9971003
datatype_dict[current_transform_type_split[1]]
9981004
)
999-
in_transform_dict["inputDimension"] = int(current_transform_type_split[2])
1000-
in_transform_dict["outputDimension"] = int(current_transform_type_split[3])
1001-
in_transform_dict["transformType"] = current_transform_type_split[0]
1005+
transform_type["inputDimension"] = int(current_transform_type_split[2])
1006+
transform_type["outputDimension"] = int(current_transform_type_split[3])
10021007

1003-
in_transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
1004-
in_transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()
1008+
transform_dict = dict()
1009+
transform_dict['transformType'] = transform_type
1010+
transform_dict["name"] = current_transform.GetObjectName()
1011+
1012+
transform_dict["inputSpaceName"] = current_transform.GetInputSpaceName()
1013+
transform_dict["outputSpaceName"] = current_transform.GetOutputSpaceName()
10051014

10061015
# To avoid copying the parameters for the Composite Transform
10071016
# as it is a copy of child transforms.
10081017
if "Composite" not in current_transform_type_split[0]:
10091018
p = np.array(current_transform.GetParameters())
1010-
in_transform_dict["parameters"] = p
1019+
transform_dict["parameters"] = p
10111020

10121021
fp = np.array(current_transform.GetFixedParameters())
1013-
in_transform_dict["fixedParameters"] = fp
1022+
transform_dict["fixedParameters"] = fp
10141023

1015-
in_transform_dict["numberOfParameters"] = p.shape[0]
1016-
in_transform_dict["numberOfFixedParameters"] = fp.shape[0]
1024+
transform_dict["numberOfParameters"] = p.shape[0]
1025+
transform_dict["numberOfFixedParameters"] = fp.shape[0]
10171026

1018-
return in_transform_dict
1027+
return transform_dict
10191028

10201029
dict_array = []
1021-
transform_type = transform.GetTransformTypeAsString()
1022-
if "CompositeTransform" in transform_type:
1023-
# Add the transforms inside the composite transform
1024-
# range is over-ridden so using this hack to create a list
1025-
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
1026-
current_transform = transform.GetNthTransform(i)
1027-
dict_array.append(update_transform_dict(current_transform))
1030+
multi = False
1031+
def add_transform_dict(transform):
1032+
transform_type = transform.GetTransformTypeAsString()
1033+
if "CompositeTransform" in transform_type:
1034+
# Add the transforms inside the composite transform
1035+
# range is over-ridden so using this hack to create a list
1036+
for i, _ in enumerate([0] * transform.GetNumberOfTransforms()):
1037+
current_transform = transform.GetNthTransform(i)
1038+
dict_array.append(update_transform_dict(current_transform))
1039+
return True
1040+
else:
1041+
dict_array.append(update_transform_dict(transform))
1042+
return False
1043+
if isinstance(transform, list):
1044+
multi = True
1045+
for t in transform:
1046+
add_transform_dict(t)
10281047
else:
1029-
dict_array.append(update_transform_dict(transform))
1048+
multi = add_transform_dict(transform)
10301049

1031-
return dict_array
1050+
if multi:
1051+
return dict_array
1052+
else:
1053+
return dict_array[0]
10321054

1055+
def transform_from_dict(transform_dict: Union[Dict, List[Dict]]) -> "itkt.TransformBase":
1056+
"""Deserialize a dictionary representing an itk.Transform object.
10331057
1034-
def transform_from_dict(transform_dict: Dict) -> "itkt.TransformBase":
1058+
If the dictionary represents a list of transforms, then a Composite Transform is returned."""
10351059
import itk
10361060

10371061
def set_parameters(transform, transform_parameters, transform_fixed_parameters, data_type):
@@ -1055,35 +1079,41 @@ def special_transform_check(transform_name):
10551079

10561080
parametersValueType_dict = {"float32": itk.F, "float64": itk.D}
10571081

1082+
if not isinstance(transform_dict, list):
1083+
transform_dict = [transform_dict]
1084+
10581085
# Loop over all the transforms in the dictionary
10591086
transforms_list = []
10601087
for i, _ in enumerate(transform_dict):
1061-
data_type = parametersValueType_dict[transform_dict[i]["parametersValueType"]]
1088+
transform_type = transform_dict[i]["transformType"]
1089+
data_type = parametersValueType_dict[transform_type["parametersValueType"]]
1090+
1091+
transform_parameterization = transform_type["transformParameterization"] + 'Transform'
10621092

10631093
# No template parameter needed for transforms having 2D or 3D name
10641094
# Also for some selected transforms
1065-
if special_transform_check(transform_dict[i]["transformType"]):
1066-
transform_template = getattr(itk, transform_dict[i]["transformType"])
1095+
if special_transform_check(transform_parameterization):
1096+
transform_template = getattr(itk, transform_parameterization)
10671097
transform = transform_template[data_type].New()
10681098
# Currently only BSpline Transform has 3 template parameters
10691099
# For future extensions the information will have to be encoded in
10701100
# the transformType variable. The transform object once added in a
10711101
# composite transform lose the information for other template parameters ex. BSpline.
10721102
# The Spline order is fixed as 3 here.
1073-
elif transform_dict[i]["transformType"] == "BSplineTransform":
1074-
transform_template = getattr(itk, transform_dict[i]["transformType"])
1103+
elif transform_parameterization == "BSplineTransform":
1104+
transform_template = getattr(itk, transform_parameterization)
10751105
transform = transform_template[
1076-
data_type, transform_dict[i]["inputDimension"], 3
1106+
data_type, transform_type["inputDimension"], 3
10771107
].New()
10781108
else:
1079-
transform_template = getattr(itk, transform_dict[i]["transformType"])
1109+
transform_template = getattr(itk, transform_parameterization)
10801110
if len(transform_template.items()[0][0]) > 2:
10811111
transform = transform_template[
1082-
data_type, transform_dict[i]["inputDimension"], transform_dict[i]["outputDimension"]
1112+
data_type, transform_type["inputDimension"], transform_type["outputDimension"]
10831113
].New()
10841114
else:
10851115
transform = transform_template[
1086-
data_type, transform_dict[i]["inputDimension"]
1116+
data_type, transform_type["inputDimension"]
10871117
].New()
10881118

10891119
transform.SetObjectName(transform_dict[i]["name"])
@@ -1102,8 +1132,8 @@ def special_transform_check(transform_name):
11021132
if len(transforms_list) > 1:
11031133
# Create a Composite Transform object
11041134
# and add all the transforms in it.
1105-
data_type = parametersValueType_dict[transform_dict[0]["parametersValueType"]]
1106-
transform = itk.CompositeTransform[data_type, transforms_list[0]['inputDimension']].New()
1135+
data_type = parametersValueType_dict[transform_dict[0]["transformType"]["parametersValueType"]]
1136+
transform = itk.CompositeTransform[data_type, transforms_list[0]["transformType"]['inputDimension']].New()
11071137
for current_transform in transforms_list:
11081138
transform.AddTransform(current_transform)
11091139
else:

0 commit comments

Comments
 (0)