44import requests
55import torch
66from io import BytesIO
7- from requests .exceptions import RequestException
87from ... import Condition
98from ... import LabelTensor
109from ...operator import laplacian
1110from ...domain import CartesianDomain
1211from ...equation import Equation , FixedValue
1312from ...problem import SpatialProblem , InverseProblem
14- from ...utils import custom_warning_format
13+ from ...utils import custom_warning_format , check_consistency
1514
1615warnings .formatwarning = custom_warning_format
1716warnings .filterwarnings ("always" , category = ResourceWarning )
1817
1918
20- def _load_tensor_from_url (url , labels ):
19+ def _load_tensor_from_url (url , labels , timeout = 10 ):
2120 """
2221 Downloads a tensor file from a URL and wraps it in a LabelTensor.
2322
@@ -28,21 +27,24 @@ def _load_tensor_from_url(url, labels):
2827
2928 :param str url: URL to the remote `.pth` tensor file.
3029 :param list[str] | tuple[str] labels: Labels for the resulting LabelTensor.
30+ :param int timeout: Timeout for the request in seconds.
3131 :return: A LabelTensor object if successful, otherwise None.
3232 :rtype: LabelTensor | None
3333 """
34+ # Try to download the tensor file from the given URL
3435 try :
35- response = requests .get (url )
36+ response = requests .get (url , timeout = timeout )
3637 response .raise_for_status ()
3738 tensor = torch .load (
3839 BytesIO (response .content ), weights_only = False
3940 ).tensor .detach ()
4041 return LabelTensor (tensor , labels )
41- except RequestException as e :
42- print (
43- "Could not download data for 'InversePoisson2DSquareProblem' "
44- f"from '{ url } '. "
45- f"Reason: { e } . Skipping data loading." ,
42+
43+ # If the request fails, issue a warning and return None
44+ except requests .exceptions .RequestException as e :
45+ warnings .warn (
46+ f"Could not download data for 'InversePoisson2DSquareProblem' "
47+ f"from '{ url } '. Reason: { e } . Skipping data loading." ,
4648 ResourceWarning ,
4749 )
4850 return None
@@ -66,19 +68,6 @@ def laplace_equation(input_, output_, params_):
6668 return delta_u - force_term
6769
6870
69- # loading data
70- input_url = (
71- "https://github.com/mathLab/PINA/raw/refs/heads/master"
72- "/tutorials/tutorial7/data/pts_0.5_0.5"
73- )
74- output_url = (
75- "https://github.com/mathLab/PINA/raw/refs/heads/master"
76- "/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
77- )
78- input_data = _load_tensor_from_url (input_url , ["x" , "y" , "mu1" , "mu2" ])
79- output_data = _load_tensor_from_url (output_url , ["u" ])
80-
81-
8271class InversePoisson2DSquareProblem (SpatialProblem , InverseProblem ):
8372 r"""
8473 Implementation of the inverse 2-dimensional Poisson problem in the square
@@ -113,5 +102,50 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
113102 "D" : Condition (domain = "D" , equation = Equation (laplace_equation )),
114103 }
115104
116- if input_data is not None and input_data is not None :
117- conditions ["data" ] = Condition (input = input_data , target = output_data )
105+ def __init__ (self , load = True , data_size = 1.0 ):
106+ """
107+ Initialization of the :class:`InversePoisson2DSquareProblem`.
108+
109+ :param bool load: If True, it attempts to load data from remote URLs.
110+ Set to False to skip data loading (e.g., if no internet connection).
111+ :param float data_size: The fraction of the total data to use for the
112+ "data" condition. If set to 1.0, all available data is used.
113+ If set to 0.0, no data is used. Default is 1.0.
114+ :raises ValueError: If `data_size` is not in the range [0.0, 1.0].
115+ :raises ValueError: If `data_size` is not a float.
116+ """
117+ super ().__init__ ()
118+
119+ # Check consistency
120+ check_consistency (load , bool )
121+ check_consistency (data_size , float )
122+ if not 0.0 <= data_size <= 1.0 :
123+ raise ValueError (
124+ f"data_size must be in the range [0.0, 1.0], got { data_size } ."
125+ )
126+
127+ # Load data if requested
128+ if load :
129+
130+ # Define URLs for input and output data
131+ input_url = (
132+ "https://github.com/mathLab/PINA/raw/refs/heads/master"
133+ "/tutorials/tutorial7/data/pts_0.5_0.5"
134+ )
135+ output_url = (
136+ "https://github.com/mathLab/PINA/raw/refs/heads/master"
137+ "/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
138+ )
139+
140+ # Define input and output data
141+ input_data = _load_tensor_from_url (
142+ input_url , ["x" , "y" , "mu1" , "mu2" ]
143+ )
144+ output_data = _load_tensor_from_url (output_url , ["u" ])
145+
146+ # Add the "data" condition
147+ if input_data is not None and output_data is not None :
148+ n_data = int (input_data .shape [0 ] * data_size )
149+ self .conditions ["data" ] = Condition (
150+ input = input_data [:n_data ], target = output_data [:n_data ]
151+ )
0 commit comments