
Bug in Weighting #741

@AleDinve

Description

Describe the bug
When using weighting schemes such as self_adaptive_weighting, the weights update iterates over all initialized parameters, not just the trained ones (those with actual gradients). This makes the code crash when training models with unused or frozen weights.
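For context, this is standard PyTorch behavior rather than anything PINA-specific: a parameter that never enters the computation graph keeps p.grad == None after backward(), so flattening its gradient unconditionally fails. A minimal illustration (variable names are mine, not from PINA):

import torch

# A parameter that takes part in the loss receives a gradient...
used = torch.nn.Parameter(torch.ones(1))
# ...while one that never enters the computation graph does not.
unused = torch.nn.Parameter(torch.zeros(1))

loss = (2.0 * used).sum()
loss.backward()
print(used.grad)    # tensor([2.])
print(unused.grad)  # None -> calling .flatten() on this raises AttributeError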

To Reproduce

import torch
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.solver import PINN
from pina.loss import SelfAdaptiveWeighting
from pina import Trainer
from pina.optim import TorchOptimizer


class Net(torch.nn.Module):
    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        self.mlp = FeedForward(
            input_dimensions=input_dim, 
            output_dimensions=output_dim, 
            layers=[10 for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.mlp(x)

class Net_biased(torch.nn.Module):
    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        self.mlp = FeedForward(
            input_dimensions=input_dim, 
            output_dimensions=output_dim, 
            layers=[10 for _ in range(num_layers)]
        )
        # Extra parameter, intentionally never used in forward():
        # it never receives a gradient, so its .grad stays None.
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.mlp(x)
    

problem = Poisson2DSquareProblem()
problem.discretise_domain(10)
model = Net(2, 1, 2)  # no unused parameters: this run trains fine
weighting = SelfAdaptiveWeighting()
solver = PINN(problem, 
              model,
              optimizer=TorchOptimizer(torch.optim.Adam, lr=1e-3),
              weighting=weighting)

trainer = Trainer(
    solver,
    max_epochs=100,
    train_size=0.8,
    test_size=0.2,
    batch_size=10,
    accelerator="cpu",
    enable_model_summary=False,
)

trainer.train()


model = Net_biased(2, 1, 2)  # has an unused parameter: this run crashes
weighting = SelfAdaptiveWeighting()
solver = PINN(problem, 
              model,
              optimizer=TorchOptimizer(torch.optim.Adam, lr=1e-3),
              weighting=weighting)

trainer = Trainer(
    solver,
    max_epochs=100,
    train_size=0.8,
    test_size=0.2,
    batch_size=10,
    accelerator="cpu",
    enable_model_summary=False,
)

trainer.train()
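Frozen weights (mentioned above) should trigger the same crash, since frozen parameters also keep p.grad == None. A hedged variant, not part of the original script and untested, based on the traceback below:

model = Net(2, 1, 2)
# Freeze the first parameter tensor; it never receives a gradient,
# so self_adaptive_weighting should hit the same AttributeError.
next(model.parameters()).requires_grad_(False)

Training this model with the same solver/trainer setup as above is expected to fail in the same way.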

Expected behavior
Both models were expected to train smoothly. The first one (Net) does; the second one (Net_biased), which carries a parameter that never receives a gradient, crashes at the first optimizer step.

Output
Epoch 99: 100%|█| 1/1 [00:00<00:00, 97.48it/s, v_num=22, g1_loss_step=0.668, g2_loss_step=0.114, g3_loss_step=0.663, g4_loss_step=0]
`Trainer.fit` stopped: `max_epochs=100` reached.
Epoch 99: 100%|█| 1/1 [00:00<00:00, 54.82it/s, v_num=22, g1_loss_step=0.668, g2_loss_step=0.114, g3_loss_step=0.663, g4_loss_step=0]
💡 Tip: For seamless cloud uploads and versioning, try installing litmodels (https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/setup.py: PossibleUserWarning: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/configuration_validator.py: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/pina/solver/physics_informed_solver/pinn_interface.py: UserWarning: Compilation is disabled for torch >= 2.8. Forcing compilation may cause runtime errors or instability.
/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval `Trainer(log_every_n_steps=50)`. Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0:   0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/u/g/gdinvern/Desktop/filter_learning/debug.py", line 75, in <module>
    trainer.train()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/pina/trainer.py", line 230, in train
    return super().fit(self.solver, datamodule=self.data_module, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 560, in fit
    call._call_and_handle_interrupt(
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 49, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 598, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1011, in _run
    results = self._run_stage()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1055, in _run_stage
    self.fit_loop.run()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
    self.advance()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 458, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 152, in run
    self.advance(data_fetcher)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 348, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
    self._optimizer_step(batch_idx, closure)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
    call._call_lightning_module_hook(
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 177, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/core/module.py", line 1366, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/core/optimizer.py", line 154, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/plugins/precision/precision.py", line 123, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/torch/optim/lr_scheduler.py", line 133, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/torch/optim/optimizer.py", line 516, in wrapper
    out = func(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/torch/optim/optimizer.py", line 81, in _use_grad
    ret = func(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/torch/optim/adam.py", line 226, in step
    loss = closure()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/plugins/precision/precision.py", line 109, in _wrap_closure
    closure_result = closure()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
    step_output = self._step_fn()
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 329, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 391, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/pina/solver/solver.py", line 110, in training_step
    loss = self._optimization_cycle(batch=batch, **kwargs)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/pina/solver/solver.py", line 255, in _optimization_cycle
    loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/pina/loss/weighting_interface.py", line 85, in aggregate
    self._saved_weights = self.weights_update(losses)
  File "/u/g/gdinvern/miniconda3/envs/pina2/lib/python3.13/site-packages/pina/loss/self_adaptive_weighting.py", line 49, in weights_update
    [p.grad.flatten() for p in self.solver.model.parameters()]
AttributeError: 'NoneType' object has no attribute 'flatten'

Metadata

Labels: bug (Something isn't working)