Skip to content

plot_samples raises an error for Condition(input_points=..., output_points=...) in PINN #101

@Bovhasselt

Description

@Bovhasselt

Describe the bug
First of all, this is not an urgent matter: after the error is raised, the plot is still shown correctly.
Calling `plotter.plot_samples()` raises an error when the problem definition includes a data-driven `Condition`, i.e. one built with `input_points=`/`output_points=` instead of `location=`/`function=`.

Problem definition

class Heat2D(TimeDependentProblem, SpatialProblem):
    """2D transient heat problem in (x, y, t) with physics and data conditions.

    Location-based conditions pair a sampled ``Span`` with a residual
    function. The ``'data'`` condition instead supplies input/output points
    directly (``X_input_tensor`` / ``X_output_tensor``, defined outside this
    class).
    """

    # Extents of the rectangular domain in x and y, and the time horizon;
    # adjust these to the problem at hand.
    LENGTH_X = 82
    LENGTH_Y = 70
    DURATION = 2284

    output_variables = ['u']
    spatial_domain = Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y]})
    temporal_domain = Span({'t': [0, DURATION]})

    # The residual functions below take (input_, output_) and no `self`:
    # they are referenced as plain functions in `conditions` while the class
    # body executes.

    def heat_equation_2D(input_, output_):
        """Residual of u_t - c**2 * (u_xx + u_yy) in the interior (condition 1)."""
        # c**2 = 0.01/pi is the coefficient multiplying the Laplacian below,
        # i.e. the diffusivity; its appropriate value depends on the material.
        c = (0.01/torch.pi) ** 0.5

        du = grad(output_, input_)
        # Differentiate the dudx / dudy components again to get the second
        # derivatives; only ddudxdx and ddudydy are used below.
        ddu = grad(du, input_, components=['dudx','dudy'])
        return (
            du.extract(['dudt']) -
            (c**2)*(ddu.extract(['ddudxdx']) + ddu.extract(['ddudydy']))
        )

    def nil_dirichlet_x(input_, output_):
        """Residual of du/dx = 0 on the x = 0 and x = LENGTH_X edges (conditions 2 and 3).

        NOTE(review): despite the name, this constrains the derivative of u
        (a zero-flux condition), not the value of u itself.
        """
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudx']) - u_expected_boundary

    def nil_dirichlet_yL(input_, output_):
        """Residual of du/dy = 0 on the y = LENGTH_Y edge (condition 4)."""
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudy']) - u_expected_boundary

    def nil_dirichlet_y0(input_, output_):
        """Residual of du/dy = 0 on the y = 0 edge (condition 5)."""
        # TODO: make this conditional on whether the door is open.
        du = grad(output_, input_)
        u_expected_boundary = 0.0
        return du.extract(['dudy']) - u_expected_boundary

    def initial_condition(input_, output_):
        """Residual of u(x, y, 0) = sin(pi * x) at t = 0 (condition 6).

        NOTE(review): x spans [0, LENGTH_X], so sin(pi * x) oscillates many
        times across the domain -- confirm this is the intended profile.
        """
        u_expected_initial = torch.sin(torch.pi*input_.extract(['x']))
        return output_.extract(['u']) - u_expected_initial


    # Four edges, the t = 0 slice and the interior are physics conditions;
    # 'data' is the input/output (data-driven) condition.
    conditions = {
        'boundx0': Condition(location=Span({'x': 0, 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=nil_dirichlet_x),
        'boundxL': Condition(location=Span({'x': LENGTH_X, 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=nil_dirichlet_x),
        'boundy0': Condition(location=Span({'x': [0, LENGTH_X], 'y': 0, 't': [0, DURATION]}), function=nil_dirichlet_y0),
        'boundyL': Condition(location=Span({'x': [0, LENGTH_X], 'y': LENGTH_Y, 't': [0, DURATION]}), function=nil_dirichlet_yL),
        'initial': Condition(location=Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y], 't': 0}), function=initial_condition),
        'heat_eq': Condition(location=Span({'x': [0, LENGTH_X], 'y': [0, LENGTH_Y], 't': [0, DURATION]}), function=heat_equation_2D),
        'data': Condition(input_points=X_input_tensor , output_points=X_output_tensor),
        }

Pinn:

class myFeature(torch.nn.Module):
    """Extra input feature computing ``sin(pi * x)`` from the ``x`` coordinate.

    The result is returned as a one-column LabelTensor labelled ``'sin(x)'``.
    """

    def __init__(self, idx):
        super().__init__()
        # Stored for the caller's reference; forward() does not use it.
        self.idx = idx

    def forward(self, x):
        scaled = torch.pi * x.extract(['x'])
        return LabelTensor(torch.sin(scaled), ['sin(x)'])

# Build the problem and a feed-forward network whose input/output variables
# come from the problem definition, plus one extra sin(pi*x) input feature.
heat_problem = Heat2D()
model = FeedForward(
    layers=[30, 20, 10, 5],
    output_variables=heat_problem.output_variables,
    input_variables=heat_problem.input_variables,
    func=Softplus,
    extra_features=[myFeature(0)],
)

pinn = PINN(
    heat_problem,
    model,
    lr=0.01,
    error_norm='mse',
    regularizer=0)

# Sample the interior ('heat_eq') in 'grid' mode: n=10 for t and n=10 for
# (x, y). How PINA combines the two specs is not shown here.
pinn.span_pts(
    {'n': 10, 'mode': 'grid', 'variables': 't'},
    {'n': 10, 'mode': 'grid', 'variables': ['x', 'y']},
    locations=['heat_eq'])
# Sample the boundary and initial locations in 'random' mode (n=20).
# 'data' is not listed: its points come from Condition(input_points=...).
pinn.span_pts(20, 'random', locations=['boundx0', 'boundxL', 'initial', 'boundyL', 'boundy0'])
# NOTE(review): presumably 1000 epochs with progress every 100 -- confirm
# against PINN.train.
pinn.train(1000, 100)

and

# List every location holding training points; 'data' shows up here even
# though it was never passed to span_pts.
print(pinn.input_pts.keys())

gives

dict_keys(['heat_eq', 'boundx0', 'boundxL', 'initial', 'boundyL', 'boundy0', 'data'])

and when plotting

# plot samples
# NOTE: this raises AttributeError for the 'data' location, whose entry in
# pinn.input_pts is a Condition object (no .extract()), not sampled points.
plotter = Plotter()
plotter.plot_samples(pinn=pinn)

I get

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[58], line 3
      1 # plot samples
      2 plotter = Plotter()
----> 3 plotter.plot_samples(pinn=pinn)

File ~/opt/anaconda3/lib/python3.8/site-packages/pina/plotter.py:46, in Plotter.plot_samples(self, pinn, variables)
     44 ax = fig.add_subplot(projection=proj)
     45 for location in pinn.input_pts:
---> 46     coords = pinn.input_pts[location].extract(variables).T.detach()
     47     if coords.shape[0] == 1:  # 1D samples
     48         ax.plot(coords[0], torch.zeros(coords[0].shape), '.',
     49                 label=location)

AttributeError: 'Condition' object has no attribute 'extract'

This is followed by a plot that correctly shows all the sampled points.

Expected behavior
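I think `plot_samples` should check whether a location is an input/output (data) condition before trying to plot it. As the traceback shows, for such a location `pinn.input_pts[location]` is a `Condition` object rather than a tensor of sampled points, so it has no `.extract()` method. These entries should be skipped instead of raising an error.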

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), low priority (Low priority fix)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions