@@ -477,34 +477,12 @@ def forward(
477477
478478 # convert systems from engine to model units
479479 with record_function ("AtomisticModel::convert_units_input" ):
480- if self ._capabilities .length_unit != options .length_unit :
481- conversion = unit_conversion_factor (
482- quantity = "length" ,
483- from_unit = options .length_unit ,
484- to_unit = self ._capabilities .length_unit ,
485- )
486-
487- systems = _convert_systems_units (
488- systems ,
489- conversion ,
490- model_length_unit = self ._capabilities .length_unit ,
491- system_length_unit = options .length_unit ,
492- )
493-
494- for name , option in self ._requested_inputs .items ():
495- system_unit = str (
496- systems [0 ].get_data (name ).get_info ("unit" )
497- ) # For torchscript
498- to_unit = option .unit
499- conversion = unit_conversion_factor (
500- quantity = option .quantity ,
501- from_unit = system_unit ,
502- to_unit = to_unit ,
503- )
504-
505- _convert_systems_input_units (
506- systems , option .quantity , conversion , to_unit
507- )
480+ systems = _convert_systems_units (
481+ systems ,
482+ model_length_unit = self ._capabilities .length_unit ,
483+ system_length_unit = options .length_unit ,
484+ requested_inputs = self ._requested_inputs ,
485+ )
508486
509487 # run the actual calculations
510488 with record_function ("Model::forward" ):
@@ -948,12 +926,19 @@ def _check_inputs(
948926
949927def _convert_systems_units (
950928 systems : List [System ],
951- conversion : float ,
952929 model_length_unit : str ,
953930 system_length_unit : str ,
931+ requested_inputs : Dict [str , ModelOutput ],
954932) -> List [System ]:
955- if conversion == 1.0 :
956- return systems
933+ if model_length_unit == "" or system_length_unit == "" :
934+ # no conversion for positions/cell/NL
935+ conversion = 1.0
936+ else :
937+ conversion = unit_conversion_factor (
938+ quantity = "length" ,
939+ from_unit = system_length_unit ,
940+ to_unit = model_length_unit ,
941+ )
957942
958943 new_systems : List [System ] = []
959944 for system in systems :
@@ -978,41 +963,58 @@ def _convert_systems_units(
978963 )
979964
980965 known_data = system .known_data ()
981- if len (known_data ) != 0 :
982- warnings .warn (
983- "the model requires a different length unit "
984- f"({ model_length_unit } ) than the system ({ system_length_unit } ), "
985- f"but we don't know how to convert custom data ({ known_data } ) "
986- "accordingly" ,
987- stacklevel = 2 ,
988- )
966+ for name in known_data :
967+ if name not in requested_inputs :
968+ # not a requested input, just copy as is
969+ new_system .add_data (name , system .get_data (name ))
989970
990- for data in known_data :
991- new_system .add_data (data , system .get_data (data ))
971+ else :
972+ requested = requested_inputs [name ]
973+ tensor = system .get_data (name )
974+ unit = tensor .get_info ("unit" )
975+
976+ if requested .quantity != "" and unit is not None :
977+ conversion = unit_conversion_factor (
978+ quantity = requested .quantity ,
979+ from_unit = unit ,
980+ to_unit = requested .unit ,
981+ )
982+ else :
983+ conversion = 1.0
984+
985+ new_blocks : List [TensorBlock ] = []
986+ for block in tensor .blocks ():
987+ new_values = conversion * block .values
988+ new_block = TensorBlock (
989+ values = new_values ,
990+ samples = block .samples ,
991+ components = block .components ,
992+ properties = block .properties ,
993+ )
992994
993- new_systems .append (new_system )
995+ for parameter , gradient in block .gradients ():
996+ if len (gradient .gradients_list ()) != 0 :
997+ raise NotImplementedError (
998+ "nested gradients are not supported"
999+ )
1000+
1001+ new_gradient = TensorBlock (
1002+ values = conversion * gradient .values ,
1003+ samples = gradient .samples ,
1004+ components = gradient .components ,
1005+ properties = gradient .properties ,
1006+ )
1007+ new_block .add_gradient (parameter , new_gradient )
1008+ new_blocks .append (new_block )
9941009
995- return new_systems
1010+ new_tensor = TensorMap (
1011+ keys = tensor .keys ,
1012+ blocks = new_blocks ,
1013+ )
1014+ new_tensor .set_info ("unit" , requested .unit )
1015+ new_tensor .set_info ("quantity" , requested .quantity )
1016+ new_system .add_data (name , new_tensor )
9961017
1018+ new_systems .append (new_system )
9971019
998- def _convert_systems_input_units (
999- systems : List [System ], quantity : str , conversion : float , to_unit : str
1000- ) -> None :
1001- if conversion != 1.0 :
1002- for system in systems :
1003- tensor = system .get_data (quantity )
1004- tblock = tensor .block ()
1005- new_tensor = TensorMap (
1006- Labels ("_" , torch .tensor ([[0 ]])),
1007- [
1008- TensorBlock (
1009- values = conversion * tblock .values ,
1010- samples = tblock .samples ,
1011- components = tblock .components ,
1012- properties = tblock .properties ,
1013- )
1014- ],
1015- )
1016- new_tensor .set_info ("unit" , to_unit )
1017- new_tensor .set_info ("quantity" , quantity )
1018- system .add_data (quantity , new_tensor , override = True )
1020+ return new_systems
0 commit comments