
A graceful design to introduce third-party models as tools for validation #20464

@JohnHerry

Description & Motivation

Python 3.10.12 + pytorch_lightning 2.4.0
I need a graceful way to introduce third-party pretrained models for use during the validation steps, so that the following error is not reported:

RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, ....

Pitch

I am training a model which needs another third-party pretrained model during validation. For example, the third-party model:

class PretrainedPicGen(torch.nn.Module):
    def __init__(self, pretrained_path):
        super().__init__()  # required before registering submodules
        self.backbone = load_checkpoint(pretrained_path)

    def forward(self, to_validate):
        return self.backbone(to_validate)
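
(Aside: since this helper is inference-only, one workaround that already avoids the DDP check is to freeze it. DDP's reducer only registers parameters with requires_grad=True, so a fully frozen submodule no longer triggers the unused-parameters error. A minimal sketch, assuming load_checkpoint returns a torch.nn.Module; FrozenPretrainedPicGen is just a hypothetical variant name:)

class FrozenPretrainedPicGen(torch.nn.Module):
    def __init__(self, pretrained_path):
        super().__init__()
        self.backbone = load_checkpoint(pretrained_path)
        # Freeze everything: DDP only tracks parameters that require
        # gradients, so the frozen backbone is ignored by the
        # unused-parameters check.
        self.backbone.requires_grad_(False)
        self.backbone.eval()

    @torch.no_grad()  # inference only; skip autograd bookkeeping
    def forward(self, to_validate):
        return self.backbone(to_validate)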

And the LightningModule I am training:

class MyModel(pl.LightningModule):
    def __init__(self, my_param, third_party_pretrained_path):
        super().__init__()
        ....
        self.pretrained_pic_gen = PretrainedPicGen(third_party_pretrained_path)
        self.validation_outputs = []
    ....
    def validation_step(self, batch, *args, **kwargs):
        validation_output = self.sample(....)
        self.validation_outputs.append({"vali_out": validation_output})

    def on_validation_epoch_end(self):
        # Here we use the third-party model to post-process the validation outputs
        outputs = self.validation_outputs
        for i, output in enumerate(outputs):
            visible_output = self.pretrained_pic_gen(output["vali_out"])
            self.logger.experiment.add_image(f"validate/{i}", visible_output, self.global_step)
        self.validation_outputs.clear()  # reset for the next epoch
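
(A second workaround keeps the helper out of nn.Module registration altogether, so neither DDP nor the checkpoint state_dict ever sees its parameters. This container trick is a common pattern, not an official Lightning API; a sketch, with _pic_gen_container as a hypothetical attribute name:)

class MyModelHiddenHelper(pl.LightningModule):
    def __init__(self, my_param, third_party_pretrained_path):
        super().__init__()
        # A plain Python list is not registered as a submodule, so the
        # helper's parameters appear neither in self.parameters() (DDP
        # never sees them) nor in the checkpoint state_dict.
        self._pic_gen_container = [PretrainedPicGen(third_party_pretrained_path)]
        self.validation_outputs = []

    @property
    def pretrained_pic_gen(self):
        # self.to(device) does not reach unregistered modules, so move the
        # helper to the LightningModule's device on access.
        return self._pic_gen_container[0].to(self.device)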

and the YAML config file:

model:
  class_path: myproject.MyModel
  init_args:
    my_param: 1234
    third_party_pretrained_path: /path/to/third_party_pretrained
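
(Regarding config-side initialization: LightningCLI's jsonargparse backend can already instantiate a class from the config when the corresponding __init__ parameter is type-annotated, so the helper could be injected rather than built inside MyModel. A sketch, assuming a hypothetical MyModel variant whose __init__ accepts pretrained_pic_gen: torch.nn.Module:)

model:
  class_path: myproject.MyModel
  init_args:
    my_param: 1234
    pretrained_pic_gen:
      class_path: myproject.PretrainedPicGen
      init_args:
        pretrained_path: /path/to/third_party_pretrained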

But when I run the training, the error mentioned before is reported:

RuntimeError: It looks like your LightningModule has parameters that were not used in producing the loss returned by training_step. If this is intentional, you must enable the detection of unused parameters in DDP, ....

And I think configuring strategy=ddp_find_unused_parameters_true may not be a good solution. Is there a more graceful design here? For example, supporting extra parameters in the on_validation_epoch_end callback, and supporting graceful third-party initialization in the config file.
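
(For reference, the escape hatch dismissed above can be enabled either with strategy: ddp_find_unused_parameters_true in the config or explicitly in code; a minimal sketch. It works, but it makes DDP traverse the autograd graph on every iteration to find unused parameters, which is why it feels like a blunt instrument here:)

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

# Equivalent to `strategy: ddp_find_unused_parameters_true` in the config:
# DDP tolerates parameters that receive no gradient in training_step, at
# the cost of an extra autograd-graph traversal each iteration.
trainer = pl.Trainer(strategy=DDPStrategy(find_unused_parameters=True))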

Alternatives

No response

Additional context

No response

cc @Borda @tchaton @justusschock @awaelchli


Labels

design (Includes a design discussion), feature (Is an improvement or enhancement)
