Skip to content

Commit 24d806b

Browse files
GiovanniCanalidario-coscia
authored andcommitted
move equation attributes to correct device
1 parent 9c3e55d commit 24d806b

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

pina/equation/equation.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Module for the Equation."""
22

33
import inspect
4-
4+
import torch
55
from .equation_interface import EquationInterface
66

77

@@ -49,6 +49,10 @@ def residual(self, input_, output_, params_=None):
4949
:raises RuntimeError: If the underlying equation signature length is not
5050
2 (direct problem) or 3 (inverse problem).
5151
"""
52+
# Move the equation to the input_ device
53+
self.to(input_.device)
54+
55+
# Call the underlying equation based on its signature length
5256
if self.__len_sig == 2:
5357
return self.__equation(input_, output_)
5458
if self.__len_sig == 3:
@@ -57,3 +61,33 @@ def residual(self, input_, output_, params_=None):
5761
f"Unexpected number of arguments in equation: {self.__len_sig}. "
5862
"Expected either 2 (direct problem) or 3 (inverse problem)."
5963
)
64+
65+
def to(self, device):
66+
"""
67+
Move all tensor attributes of the Equation to the specified device.
68+
69+
:param torch.device device: The target device to move the tensors to.
70+
:return: The Equation instance moved to the specified device.
71+
:rtype: Equation
72+
"""
73+
# Iterate over all attributes of the Equation
74+
for key, val in self.__dict__.items():
75+
76+
# Move tensors in dictionaries to the specified device
77+
if isinstance(val, dict):
78+
self.__dict__[key] = {
79+
k: v.to(device) if torch.is_tensor(v) else v
80+
for k, v in val.items()
81+
}
82+
83+
# Move tensors in lists to the specified device
84+
elif isinstance(val, list):
85+
self.__dict__[key] = [
86+
v.to(device) if torch.is_tensor(v) else v for v in val
87+
]
88+
89+
# Move tensor attributes to the specified device
90+
elif torch.is_tensor(val):
91+
self.__dict__[key] = val.to(device)
92+
93+
return self

0 commit comments

Comments
 (0)