Skip to content

Commit f3dab12

Browse files
FilippoOlivondem0
authored andcommitted
Implement custom sampling logic
1 parent 3af8d6e commit f3dab12

File tree

4 files changed

+114
-46
lines changed

4 files changed

+114
-46
lines changed

pina/domain/cartesian.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ def _1d_sampler(n, mode, variables):
160160
pts_variable.labels = [variable]
161161

162162
tmp.append(pts_variable)
163-
164-
result = tmp[0]
165-
for i in tmp[1:]:
166-
result = result.append(i, mode="cross")
163+
if tmp:
164+
result = tmp[0]
165+
for i in tmp[1:]:
166+
result = result.append(i, mode="cross")
167167

168168
for variable in variables:
169169
if variable in self.fixed_.keys():
@@ -242,6 +242,8 @@ def _single_points_sample(n, variables):
242242

243243
if self.fixed_ and (not self.range_):
244244
return _single_points_sample(n, variables)
245+
if isinstance(variables, str) and variables in self.fixed_.keys():
246+
return _single_points_sample(n, variables)
245247

246248
if mode in ["grid", "chebyshev"]:
247249
return _1d_sampler(n, mode, variables).extract(variables)

pina/problem/abstract_problem.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from abc import ABCMeta, abstractmethod
44
from ..utils import check_consistency
5-
from ..domain import DomainInterface
5+
from ..domain import DomainInterface, CartesianDomain
66
from ..condition.domain_equation_condition import DomainEquationCondition
77
from ..condition import InputPointsEquationCondition
88
from copy import deepcopy
9-
from pina import LabelTensor
9+
from .. import LabelTensor
10+
from ..utils import merge_tensors
1011

1112

1213
class AbstractProblem(metaclass=ABCMeta):
@@ -21,7 +22,7 @@ class AbstractProblem(metaclass=ABCMeta):
2122

2223
def __init__(self):
2324

24-
self.discretised_domains = {}
25+
self._discretised_domains = {}
2526
# create collector to manage problem data
2627

2728
# create hook conditions <-> problems
@@ -53,6 +54,10 @@ def batching_dimension(self):
5354
def batching_dimension(self, value):
5455
self._batching_dimension = value
5556

57+
@property
58+
def discretised_domains(self):
59+
return self._discretised_domains
60+
5661
# TODO this should be erase when dataloading will interface collector,
5762
# kept only for back compatibility
5863
@property
@@ -62,7 +67,7 @@ def input_pts(self):
6267
if hasattr(cond, "input_points"):
6368
to_return[cond_name] = cond.input_points
6469
elif hasattr(cond, "domain"):
65-
to_return[cond_name] = self.discretised_domains[cond.domain]
70+
to_return[cond_name] = self._discretised_domains[cond.domain]
6671
return to_return
6772

6873
def __deepcopy__(self, memo):
@@ -139,9 +144,10 @@ def conditions(self):
139144
return self.conditions
140145

141146
def discretise_domain(self,
142-
n,
147+
n=None,
143148
mode="random",
144-
domains="all"):
149+
domains="all",
150+
sample_rules=None):
145151
"""
146152
Generate a set of points to span the `Location` of all the conditions of
147153
the problem.
@@ -153,6 +159,8 @@ def discretise_domain(self,
153159
Available modes include: random sampling, ``random``;
154160
latin hypercube sampling, ``latin`` or ``lh``;
155161
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
162+
:param variables: variable(s) to sample, defaults to 'all'.
163+
:type variables: str | list[str]
156164
:param domains: problem's domain from where to sample, defaults to 'all'.
157165
:type domains: str | list[str]
158166
@@ -170,24 +178,55 @@ def discretise_domain(self,
170178
"""
171179

172180
# check consistecy n, mode, variables, locations
173-
check_consistency(n, int)
174-
check_consistency(mode, str)
181+
if sample_rules is not None:
182+
check_consistency(sample_rules, dict)
183+
if mode is not None:
184+
check_consistency(mode, str)
175185
check_consistency(domains, (list, str))
176186

177-
# check correct sampling mode
178-
# if mode not in DomainInterface.available_sampling_modes:
179-
# raise TypeError(f"mode {mode} not valid.")
180-
181187
# check correct location
182188
if domains == "all":
183189
domains = self.domains.keys()
184190
elif not isinstance(domains, (list)):
185191
domains = [domains]
192+
if n is not None and sample_rules is None:
193+
self._apply_default_discretization(n, mode, domains)
194+
if n is None and sample_rules is not None:
195+
self._apply_custom_discretization(sample_rules, domains)
196+
elif n is not None and sample_rules is not None:
197+
raise RuntimeError(
198+
"You can't specify both n and sample_rules at the same time."
199+
)
200+
elif n is None and sample_rules is None:
201+
raise RuntimeError(
202+
"You have to specify either n or sample_rules."
203+
)
186204

205+
def _apply_default_discretization(self, n, mode, domains):
187206
for domain in domains:
188207
self.discretised_domains[domain] = (
189-
self.domains[domain].sample(n, mode)
208+
self.domains[domain].sample(n, mode).sort_labels()
209+
)
210+
211+
def _apply_custom_discretization(self, sample_rules, domains):
212+
if sorted(list(sample_rules.keys())) != sorted(self.input_variables):
213+
raise RuntimeError(
214+
"The keys of the sample_rules dictionary must be the same as "
215+
"the input variables."
190216
)
217+
for domain in domains:
218+
if not isinstance(self.domains[domain], CartesianDomain):
219+
raise RuntimeError(
220+
"Custom discretisation can be applied only on Cartesian "
221+
"domains")
222+
discretised_tensor = []
223+
for var, rules in sample_rules.items():
224+
n, mode = rules['n'], rules['mode']
225+
points = self.domains[domain].sample(n, mode, var)
226+
discretised_tensor.append(points)
227+
228+
self.discretised_domains[domain] = merge_tensors(
229+
discretised_tensor).sort_labels()
191230

192231
def add_points(self, new_points_dict):
193232
"""

tests/test_collector.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@
1111
from pina.collector import Collector
1212

1313

14-
# def test_supervised_tensor_collector():
15-
# class SupervisedProblem(AbstractProblem):
16-
# output_variables = None
17-
# conditions = {
18-
# 'data1' : Condition(input_points=torch.rand((10,2)),
19-
# output_points=torch.rand((10,2))),
20-
# 'data2' : Condition(input_points=torch.rand((20,2)),
21-
# output_points=torch.rand((20,2))),
22-
# 'data3' : Condition(input_points=torch.rand((30,2)),
23-
# output_points=torch.rand((30,2))),
24-
# }
25-
# problem = SupervisedProblem()
26-
# collector = Collector(problem)
27-
# for v in collector.conditions_name.values():
28-
# assert v in problem.conditions.keys()
29-
# assert all(collector._is_conditions_ready.values())
14+
def test_supervised_tensor_collector():
15+
class SupervisedProblem(AbstractProblem):
16+
output_variables = None
17+
conditions = {
18+
'data1': Condition(input_points=torch.rand((10, 2)),
19+
output_points=torch.rand((10, 2))),
20+
'data2': Condition(input_points=torch.rand((20, 2)),
21+
output_points=torch.rand((20, 2))),
22+
'data3': Condition(input_points=torch.rand((30, 2)),
23+
output_points=torch.rand((30, 2))),
24+
}
25+
26+
problem = SupervisedProblem()
27+
collector = Collector(problem)
28+
for v in collector.conditions_name.values():
29+
assert v in problem.conditions.keys()
30+
3031

3132
def test_pinn_collector():
3233
def laplace_equation(input_, output_):
@@ -81,7 +82,7 @@ class Poisson(SpatialProblem):
8182
def poisson_sol(self, pts):
8283
return -(torch.sin(pts.extract(['x']) * torch.pi) *
8384
torch.sin(pts.extract(['y']) * torch.pi)) / (
84-
2 * torch.pi ** 2)
85+
2 * torch.pi ** 2)
8586

8687
truth_solution = poisson_sol
8788

tests/test_problem.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pytest
33
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
44
from pina import LabelTensor
5+
from pina.domain import Union
6+
from pina.domain import CartesianDomain
57

68

79
def test_discretise_domain():
@@ -29,18 +31,6 @@ def test_discretise_domain():
2931
poisson_problem.discretise_domain(n)
3032

3133

32-
'''
33-
def test_sampling_few_variables():
34-
n = 10
35-
poisson_problem = Poisson()
36-
poisson_problem.discretise_domain(n,
37-
'grid',
38-
domains=['D'],
39-
variables=['x'])
40-
assert poisson_problem.discretised_domains['D'].shape[1] == 1
41-
'''
42-
43-
4434
def test_variables_correct_order_sampling():
4535
n = 10
4636
poisson_problem = Poisson()
@@ -66,3 +56,39 @@ def test_add_points():
6656
new_pts.extract('x'))
6757
assert torch.isclose(poisson_problem.discretised_domains['D'].extract('y'),
6858
new_pts.extract('y'))
59+
60+
@pytest.mark.parametrize(
61+
"mode",
62+
[
63+
'random',
64+
'grid'
65+
]
66+
)
67+
def test_custom_sampling_logic(mode):
68+
poisson_problem = Poisson()
69+
sampling_rules = {
70+
'x': {'n': 100, 'mode': mode},
71+
'y': {'n': 50, 'mode': mode}
72+
}
73+
poisson_problem.discretise_domain(sample_rules=sampling_rules)
74+
for domain in ['g1', 'g2', 'g3', 'g4']:
75+
assert poisson_problem.discretised_domains[domain].shape[0] == 100 * 50
76+
assert poisson_problem.discretised_domains[domain].labels == ['x', 'y']
77+
78+
@pytest.mark.parametrize(
79+
"mode",
80+
[
81+
'random',
82+
'grid'
83+
]
84+
)
85+
def test_wrong_custom_sampling_logic(mode):
86+
d2 = CartesianDomain({'x': [1,2], 'y': [0,1] })
87+
poisson_problem = Poisson()
88+
poisson_problem.domains['D'] = Union([poisson_problem.domains['D'], d2])
89+
sampling_rules = {
90+
'x': {'n': 100, 'mode': mode},
91+
'y': {'n': 50, 'mode': mode}
92+
}
93+
with pytest.raises(RuntimeError):
94+
poisson_problem.discretise_domain(sample_rules=sampling_rules)

0 commit comments

Comments
 (0)