@@ -981,57 +981,81 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict:
981981 )
982982
983983
984- def dict_from_transform (transform : "itkt.TransformBase" ) -> Dict :
984+ def dict_from_transform (transform : Union ["itkt.TransformBase" , List ["itkt.TransformBase" ]]) -> Union [List [Dict ], Dict ]:
985+ """Serialize a Python itk.Transform object to a pickable Python dictionary.
986+
987+ If the transform is a list of transforms, then a list of dictionaries is returned.
988+ If the transform is a single, non-Composite transform, then a single dictionary is returned.
989+ Composite transforms and nested composite transforms are flattened into a list of dictionaries.
990+ """
985991 import itk
992+ datatype_dict = {"double" : itk .D , "float" : itk .F }
986993
987994 def update_transform_dict (current_transform ):
988995 current_transform_type = current_transform .GetTransformTypeAsString ()
989996 current_transform_type_split = current_transform_type .split ("_" )
990- component = itk .template (current_transform )
991997
992- in_transform_dict = dict ()
993- in_transform_dict ["name" ] = current_transform .GetObjectName ()
998+ transform_type = dict ()
999+ transform_parameterization = current_transform_type_split [0 ].replace ("Transform" , "" )
1000+ transform_type ["transformParameterization" ] = transform_parameterization
9941001
995- datatype_dict = {"double" : itk .D , "float" : itk .F }
996- in_transform_dict ["parametersValueType" ] = python_to_js (
1002+ transform_type ["parametersValueType" ] = python_to_js (
9971003 datatype_dict [current_transform_type_split [1 ]]
9981004 )
999- in_transform_dict ["inputDimension" ] = int (current_transform_type_split [2 ])
1000- in_transform_dict ["outputDimension" ] = int (current_transform_type_split [3 ])
1001- in_transform_dict ["transformType" ] = current_transform_type_split [0 ]
1005+ transform_type ["inputDimension" ] = int (current_transform_type_split [2 ])
1006+ transform_type ["outputDimension" ] = int (current_transform_type_split [3 ])
10021007
1003- in_transform_dict ["inputSpaceName" ] = current_transform .GetInputSpaceName ()
1004- in_transform_dict ["outputSpaceName" ] = current_transform .GetOutputSpaceName ()
1008+ transform_dict = dict ()
1009+ transform_dict ['transformType' ] = transform_type
1010+ transform_dict ["name" ] = current_transform .GetObjectName ()
1011+
1012+ transform_dict ["inputSpaceName" ] = current_transform .GetInputSpaceName ()
1013+ transform_dict ["outputSpaceName" ] = current_transform .GetOutputSpaceName ()
10051014
10061015 # To avoid copying the parameters for the Composite Transform
10071016 # as it is a copy of child transforms.
10081017 if "Composite" not in current_transform_type_split [0 ]:
10091018 p = np .array (current_transform .GetParameters ())
1010- in_transform_dict ["parameters" ] = p
1019+ transform_dict ["parameters" ] = p
10111020
10121021 fp = np .array (current_transform .GetFixedParameters ())
1013- in_transform_dict ["fixedParameters" ] = fp
1022+ transform_dict ["fixedParameters" ] = fp
10141023
1015- in_transform_dict ["numberOfParameters" ] = p .shape [0 ]
1016- in_transform_dict ["numberOfFixedParameters" ] = fp .shape [0 ]
1024+ transform_dict ["numberOfParameters" ] = p .shape [0 ]
1025+ transform_dict ["numberOfFixedParameters" ] = fp .shape [0 ]
10171026
1018- return in_transform_dict
1027+ return transform_dict
10191028
10201029 dict_array = []
1021- transform_type = transform .GetTransformTypeAsString ()
1022- if "CompositeTransform" in transform_type :
1023- # Add the transforms inside the composite transform
1024- # range is over-ridden so using this hack to create a list
1025- for i , _ in enumerate ([0 ] * transform .GetNumberOfTransforms ()):
1026- current_transform = transform .GetNthTransform (i )
1027- dict_array .append (update_transform_dict (current_transform ))
1030+ multi = False
1031+ def add_transform_dict (transform ):
1032+ transform_type = transform .GetTransformTypeAsString ()
1033+ if "CompositeTransform" in transform_type :
1034+ # Add the transforms inside the composite transform
1035+ # range is over-ridden so using this hack to create a list
1036+ for i , _ in enumerate ([0 ] * transform .GetNumberOfTransforms ()):
1037+ current_transform = transform .GetNthTransform (i )
1038+ dict_array .append (update_transform_dict (current_transform ))
1039+ return True
1040+ else :
1041+ dict_array .append (update_transform_dict (transform ))
1042+ return False
1043+ if isinstance (transform , list ):
1044+ multi = True
1045+ for t in transform :
1046+ add_transform_dict (t )
10281047 else :
1029- dict_array . append ( update_transform_dict ( transform ) )
1048+ multi = add_transform_dict ( transform )
10301049
1031- return dict_array
1050+ if multi :
1051+ return dict_array
1052+ else :
1053+ return dict_array [0 ]
10321054
1055+ def transform_from_dict (transform_dict : Union [Dict , List [Dict ]]) -> "itkt.TransformBase" :
1056+ """Deserialize a dictionary representing an itk.Transform object.
10331057
1034- def transform_from_dict ( transform_dict : Dict ) -> "itkt.TransformBase" :
1058+ If the dictionary represents a list of transforms, then a Composite Transform is returned."""
10351059 import itk
10361060
10371061 def set_parameters (transform , transform_parameters , transform_fixed_parameters , data_type ):
@@ -1055,35 +1079,41 @@ def special_transform_check(transform_name):
10551079
10561080 parametersValueType_dict = {"float32" : itk .F , "float64" : itk .D }
10571081
1082+ if not isinstance (transform_dict , list ):
1083+ transform_dict = [transform_dict ]
1084+
10581085 # Loop over all the transforms in the dictionary
10591086 transforms_list = []
10601087 for i , _ in enumerate (transform_dict ):
1061- data_type = parametersValueType_dict [transform_dict [i ]["parametersValueType" ]]
1088+ transform_type = transform_dict [i ]["transformType" ]
1089+ data_type = parametersValueType_dict [transform_type ["parametersValueType" ]]
1090+
1091+ transform_parameterization = transform_type ["transformParameterization" ] + 'Transform'
10621092
10631093 # No template parameter needed for transforms having 2D or 3D name
10641094 # Also for some selected transforms
1065- if special_transform_check (transform_dict [ i ][ "transformType" ] ):
1066- transform_template = getattr (itk , transform_dict [ i ][ "transformType" ] )
1095+ if special_transform_check (transform_parameterization ):
1096+ transform_template = getattr (itk , transform_parameterization )
10671097 transform = transform_template [data_type ].New ()
10681098 # Currently only BSpline Transform has 3 template parameters
10691099 # For future extensions the information will have to be encoded in
10701100 # the transformType variable. The transform object once added in a
10711101 # composite transform lose the information for other template parameters ex. BSpline.
10721102 # The Spline order is fixed as 3 here.
1073- elif transform_dict [ i ][ "transformType" ] == "BSplineTransform" :
1074- transform_template = getattr (itk , transform_dict [ i ][ "transformType" ] )
1103+ elif transform_parameterization == "BSplineTransform" :
1104+ transform_template = getattr (itk , transform_parameterization )
10751105 transform = transform_template [
1076- data_type , transform_dict [ i ] ["inputDimension" ], 3
1106+ data_type , transform_type ["inputDimension" ], 3
10771107 ].New ()
10781108 else :
1079- transform_template = getattr (itk , transform_dict [ i ][ "transformType" ] )
1109+ transform_template = getattr (itk , transform_parameterization )
10801110 if len (transform_template .items ()[0 ][0 ]) > 2 :
10811111 transform = transform_template [
1082- data_type , transform_dict [ i ][ "inputDimension" ], transform_dict [ i ] ["outputDimension" ]
1112+ data_type , transform_type [ "inputDimension" ], transform_type ["outputDimension" ]
10831113 ].New ()
10841114 else :
10851115 transform = transform_template [
1086- data_type , transform_dict [ i ] ["inputDimension" ]
1116+ data_type , transform_type ["inputDimension" ]
10871117 ].New ()
10881118
10891119 transform .SetObjectName (transform_dict [i ]["name" ])
@@ -1102,8 +1132,8 @@ def special_transform_check(transform_name):
11021132 if len (transforms_list ) > 1 :
11031133 # Create a Composite Transform object
11041134 # and add all the transforms in it.
1105- data_type = parametersValueType_dict [transform_dict [0 ]["parametersValueType" ]]
1106- transform = itk .CompositeTransform [data_type , transforms_list [0 ]['inputDimension' ]].New ()
1135+ data_type = parametersValueType_dict [transform_dict [0 ]["transformType" ][ " parametersValueType" ]]
1136+ transform = itk .CompositeTransform [data_type , transforms_list [0 ]["transformType" ][ 'inputDimension' ]].New ()
11071137 for current_transform in transforms_list :
11081138 transform .AddTransform (current_transform )
11091139 else :
0 commit comments