11"""Module for the Equation."""
22
33import inspect
4-
4+ import torch
55from .equation_interface import EquationInterface
66
77
@@ -49,6 +49,10 @@ def residual(self, input_, output_, params_=None):
4949 :raises RuntimeError: If the underlying equation signature length is not
5050 2 (direct problem) or 3 (inverse problem).
5151 """
52+ # Move the equation to the input_ device
53+ self .to (input_ .device )
54+
55+ # Call the underlying equation based on its signature length
5256 if self .__len_sig == 2 :
5357 return self .__equation (input_ , output_ )
5458 if self .__len_sig == 3 :
@@ -57,3 +61,33 @@ def residual(self, input_, output_, params_=None):
5761 f"Unexpected number of arguments in equation: { self .__len_sig } . "
5862 "Expected either 2 (direct problem) or 3 (inverse problem)."
5963 )
64+
65+ def to (self , device ):
66+ """
67+ Move all tensor attributes of the Equation to the specified device.
68+
69+ :param torch.device device: The target device to move the tensors to.
70+ :return: The Equation instance moved to the specified device.
71+ :rtype: Equation
72+ """
73+ # Iterate over all attributes of the Equation
74+ for key , val in self .__dict__ .items ():
75+
76+ # Move tensors in dictionaries to the specified device
77+ if isinstance (val , dict ):
78+ self .__dict__ [key ] = {
79+ k : v .to (device ) if torch .is_tensor (v ) else v
80+ for k , v in val .items ()
81+ }
82+
83+ # Move tensors in lists to the specified device
84+ elif isinstance (val , list ):
85+ self .__dict__ [key ] = [
86+ v .to (device ) if torch .is_tensor (v ) else v for v in val
87+ ]
88+
89+ # Move tensor attributes to the specified device
90+ elif torch .is_tensor (val ):
91+ self .__dict__ [key ] = val .to (device )
92+
93+ return self
0 commit comments