-
Notifications
You must be signed in to change notification settings - Fork 92
Closed
Labels
Description
Describe the bug
First of all, this is not an urgent matter, since the plot is still shown correctly after the error.
`Plotter.plot_samples()` raises an error when the problem definition includes a data Condition (one built from `input_points`/`output_points` rather than from a location and a function).
Problem definition
class Heat2D(TimeDependentProblem, SpatialProblem):
    """Transient 2D heat problem on the rectangle [0, LENGTH_X] x [0, LENGTH_Y].

    Time runs over [0, DURATION] and the single output variable is ``u``.
    The ``conditions`` dict combines location-based conditions (a ``Span``
    plus a residual function) with one input/output ``'data'`` condition.
    """

    # Define these yourself
    LENGTH_X = 82
    LENGTH_Y = 70
    DURATION = 2284

    output_variables = ['u']
    spatial_domain = Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y]})
    temporal_domain = Span({'t': [0, DURATION]})

    # The residual functions below take no ``self``: they are referenced
    # directly by the ``conditions`` dict at the end of the class body, so
    # they must be defined before it.

    def heat_equation_2D(input_, output_):
        '''(1) Heat-equation residual: du/dt - c**2 * (d2u/dx2 + d2u/dy2).'''
        # The coefficient actually applied below is c**2 (= 0.01/pi here), so
        # c**2 plays the role of the thermal diffusivity. It depends on the
        # material, so look up the right value for the material being modelled.
        c = (0.01/torch.pi) ** 0.5
        du = grad(output_, input_)
        # Differentiate dudx/dudy again to obtain 'ddudxdx' and 'ddudydy'.
        ddu = grad(du, input_, components=['dudx','dudy'])
        return (
            du.extract(['dudt']) -
            (c**2)*(ddu.extract(['ddudxdx']) + ddu.extract(['ddudydy']))
        )

    def nil_dirichlet_x(input_, output_):
        '''(2) and (3) Residual du/dx - 0, used on x = 0 and x = LENGTH_X.'''
        # NOTE(review): despite the name, this enforces du/dx = 0 (a
        # homogeneous Neumann / zero-flux condition), not u = 0 (Dirichlet).
        du = grad(output_, input_)
        # Target value for the derivative (not for u itself).
        u_expected_boundary = 0.0
        return du.extract(['dudx']) - u_expected_boundary

    def nil_dirichlet_yL(input_, output_):
        '''(4) Residual du/dy - 0, used on y = LENGTH_Y.'''
        # NOTE(review): like nil_dirichlet_x, this is a zero-derivative
        # (Neumann) condition, not a Dirichlet one.
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudy']) - u_expected_boundary

    def nil_dirichlet_y0(input_, output_):
        '''(5) Residual du/dy - 0, used on y = 0.'''
        # TODO: make this conditional on whether the door is open.
        # Currently identical to nil_dirichlet_yL; kept separate so the y = 0
        # edge can diverge once the TODO above is implemented.
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudy']) - u_expected_boundary

    def initial_condition(input_, output_):
        '''(6) Initial-condition residual: u - sin(pi * x), used at t = 0.'''
        # NOTE(review): sin(pi*x) has period 2 in x, so over LENGTH_X = 82 it
        # oscillates ~41 times across the plate -- confirm this is intended.
        u_expected_initial = torch.sin(torch.pi*input_.extract(['x']))
        return output_.extract(['u']) - u_expected_initial

    # Location-based conditions pair a Span (points to sample) with a
    # residual function; 'data' instead carries explicit input/output points.
    conditions = {
        'boundx0': Condition(location=Span({'x': 0, 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=nil_dirichlet_x),
        'boundxL': Condition(location=Span({'x': LENGTH_X, 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=nil_dirichlet_x),
        'boundy0': Condition(location=Span({'x': [0, LENGTH_X], 'y': 0, 't': [0, DURATION]}), function=nil_dirichlet_y0),
        'boundyL': Condition(location=Span({'x': [0, LENGTH_X], 'y': LENGTH_Y, 't': [0, DURATION]}), function=nil_dirichlet_yL),
        'initial': Condition(location=Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y], 't': 0}), function=initial_condition),
        'heat_eq': Condition(location=Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=heat_equation_2D),
        # Input/output (data) condition; X_input_tensor and X_output_tensor
        # are defined outside this snippet.
        'data': Condition(input_points=X_input_tensor , output_points=X_output_tensor),
    }
Pinn:
class myFeature(torch.nn.Module):
    """Extra input feature: sin(pi * x), computed from the ``x`` column.

    The value is returned as a LabelTensor labelled ``'sin(x)'`` (note the
    label omits the pi factor that the computation applies).
    """

    def __init__(self, idx):
        super().__init__()
        # TODO(review): idx is stored but forward() never reads it.
        self.idx = idx

    def forward(self, x):
        x_column = x.extract(['x'])
        feature = torch.sin(torch.pi * x_column)
        return LabelTensor(feature, ['sin(x)'])
# Instantiate the problem defined above.
heat_problem = Heat2D()
# Feed-forward network mapping the problem's input variables to its output
# variables; myFeature(0) (sin(pi*x)) is passed as an extra feature.
model = FeedForward(
    layers=[30, 20, 10, 5],
    output_variables=heat_problem.output_variables,
    input_variables=heat_problem.input_variables,
    func=Softplus,
    extra_features=[myFeature(0)],
)
pinn = PINN(
    heat_problem,
    model,
    lr=0.01,
    error_norm='mse',
    regularizer=0)
# Grid-sample the interior 'heat_eq' location: one spec for t, one for (x, y).
pinn.span_pts(
    {'n': 10, 'mode': 'grid', 'variables': 't'},
    {'n': 10, 'mode': 'grid', 'variables': ['x', 'y']},
    locations=['heat_eq'])
# Randomly sample the boundary and initial-condition locations.
# 'data' is not listed here: it is an input/output condition that already
# carries its own points.
pinn.span_pts(20, 'random', locations=['boundx0', 'boundxL', 'initial', 'boundyL', 'boundy0'])
# Presumably (epochs, logging frequency) -- confirm against the PINA version in use.
pinn.train(1000, 100)
and
# Show which condition names the PINN holds input points for.
print(pinn.input_pts.keys())
gives
dict_keys(['heat_eq', 'boundx0', 'boundxL', 'initial', 'boundyL', 'boundy0', 'data'])
and when plotting
# plot samples
# NOTE: this currently raises AttributeError, because pinn.input_pts['data']
# is a Condition object (which has no .extract), not a tensor of sampled points.
plotter = Plotter()
plotter.plot_samples(pinn=pinn)
I get
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[58], line 3
1 # plot samples
2 plotter = Plotter()
----> 3 plotter.plot_samples(pinn=pinn)
File ~/opt/anaconda3/lib/python3.8/site-packages/pina/plotter.py:46, in Plotter.plot_samples(self, pinn, variables)
44 ax = fig.add_subplot(projection=proj)
45 for location in pinn.input_pts:
---> 46 coords = pinn.input_pts[location].extract(variables).T.detach()
47 if coords.shape[0] == 1: # 1D samples
48 ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
49 label=location)
AttributeError: 'Condition' object has no attribute 'extract'
The traceback is followed by a plot that correctly shows the samples of all the location-based (function) conditions.
Expected behavior
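I think `plot_samples` should check whether a location is an input–output (data) condition before trying to plot it: for such a condition, `pinn.input_pts[location]` is a `Condition` object rather than a tensor of sampled points, so calling `.extract()` on it fails.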