22
33from abc import ABCMeta , abstractmethod
44from ..utils import check_consistency
5- from ..domain import DomainInterface
5+ from ..domain import DomainInterface , CartesianDomain
66from ..condition .domain_equation_condition import DomainEquationCondition
77from ..condition import InputPointsEquationCondition
88from copy import deepcopy
9- from pina import LabelTensor
9+ from .. import LabelTensor
10+ from ..utils import merge_tensors
1011
1112
1213class AbstractProblem (metaclass = ABCMeta ):
@@ -21,7 +22,7 @@ class AbstractProblem(metaclass=ABCMeta):
2122
2223 def __init__ (self ):
2324
24- self .discretised_domains = {}
25+ self ._discretised_domains = {}
2526 # create collector to manage problem data
2627
2728 # create hook conditions <-> problems
@@ -53,6 +54,10 @@ def batching_dimension(self):
5354 def batching_dimension (self , value ):
5455 self ._batching_dimension = value
5556
57+ @property
58+ def discretised_domains (self ):
59+ return self ._discretised_domains
60+
5661 # TODO this should be erase when dataloading will interface collector,
5762 # kept only for back compatibility
5863 @property
@@ -62,7 +67,7 @@ def input_pts(self):
6267 if hasattr (cond , "input_points" ):
6368 to_return [cond_name ] = cond .input_points
6469 elif hasattr (cond , "domain" ):
65- to_return [cond_name ] = self .discretised_domains [cond .domain ]
70+ to_return [cond_name ] = self ._discretised_domains [cond .domain ]
6671 return to_return
6772
6873 def __deepcopy__ (self , memo ):
@@ -139,9 +144,10 @@ def conditions(self):
139144 return self .conditions
140145
141146 def discretise_domain (self ,
142- n ,
147+ n = None ,
143148 mode = "random" ,
144- domains = "all" ):
149+ domains = "all" ,
150+ sample_rules = None ):
145151 """
146152 Generate a set of points to span the `Location` of all the conditions of
147153 the problem.
@@ -153,6 +159,8 @@ def discretise_domain(self,
153159 Available modes include: random sampling, ``random``;
154160 latin hypercube sampling, ``latin`` or ``lh``;
155161 chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
162+ :param variables: variable(s) to sample, defaults to 'all'.
163+ :type variables: str | list[str]
156164 :param domains: problem's domain from where to sample, defaults to 'all'.
157165 :type domains: str | list[str]
158166
@@ -170,24 +178,55 @@ def discretise_domain(self,
170178 """
171179
172180 # check consistecy n, mode, variables, locations
173- check_consistency (n , int )
174- check_consistency (mode , str )
181+ if sample_rules is not None :
182+ check_consistency (sample_rules , dict )
183+ if mode is not None :
184+ check_consistency (mode , str )
175185 check_consistency (domains , (list , str ))
176186
177- # check correct sampling mode
178- # if mode not in DomainInterface.available_sampling_modes:
179- # raise TypeError(f"mode {mode} not valid.")
180-
181187 # check correct location
182188 if domains == "all" :
183189 domains = self .domains .keys ()
184190 elif not isinstance (domains , (list )):
185191 domains = [domains ]
192+ if n is not None and sample_rules is None :
193+ self ._apply_default_discretization (n , mode , domains )
194+ if n is None and sample_rules is not None :
195+ self ._apply_custom_discretization (sample_rules , domains )
196+ elif n is not None and sample_rules is not None :
197+ raise RuntimeError (
198+ "You can't specify both n and sample_rules at the same time."
199+ )
200+ elif n is None and sample_rules is None :
201+ raise RuntimeError (
202+ "You have to specify either n or sample_rules."
203+ )
186204
205+ def _apply_default_discretization (self , n , mode , domains ):
187206 for domain in domains :
188207 self .discretised_domains [domain ] = (
189- self .domains [domain ].sample (n , mode )
208+ self .domains [domain ].sample (n , mode ).sort_labels ()
209+ )
210+
211+ def _apply_custom_discretization (self , sample_rules , domains ):
212+ if sorted (list (sample_rules .keys ())) != sorted (self .input_variables ):
213+ raise RuntimeError (
214+ "The keys of the sample_rules dictionary must be the same as "
215+ "the input variables."
190216 )
217+ for domain in domains :
218+ if not isinstance (self .domains [domain ], CartesianDomain ):
219+ raise RuntimeError (
220+ "Custom discretisation can be applied only on Cartesian "
221+ "domains" )
222+ discretised_tensor = []
223+ for var , rules in sample_rules .items ():
224+ n , mode = rules ['n' ], rules ['mode' ]
225+ points = self .domains [domain ].sample (n , mode , var )
226+ discretised_tensor .append (points )
227+
228+ self .discretised_domains [domain ] = merge_tensors (
229+ discretised_tensor ).sort_labels ()
191230
192231 def add_points (self , new_points_dict ):
193232 """
0 commit comments