How to access the validation step outputs of a complete epoch in the on_validation_epoch_end hook of a custom callback?
#11659
-
I want to implement a custom callback which calculates a custom metric and needs all of the outputs from the complete epoch. Is there any way to pass all the outputs to `on_validation_epoch_end`? Here's the pseudo-code of the setup:

```python
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning import Callback
from torch.optim import AdamW
from transformers import AutoConfig, AutoModel, get_cosine_schedule_with_warmup


class FeedBackPrize(pl.LightningModule):
    def __init__(
        self,
        num_train_steps,
        steps_per_epoch,
        model_name: str = "allenai/longformer-base-4096",
        lr: float = 1e-5,
        num_labels: int = 16,
        multi_sample_dropout=True,
        step_scheduler_after: str = "step",
    ):
        super().__init__()
        self.learning_rate = lr
        self.model_name = model_name
        self.multi_sample_dropout = multi_sample_dropout
        self.num_train_steps = num_train_steps
        self.num_labels = num_labels
        self.steps_per_epoch = steps_per_epoch
        self.step_scheduler_after = step_scheduler_after
        hidden_dropout_prob: float = 0.1
        layer_norm_eps: float = 1e-7
        config = AutoConfig.from_pretrained(model_name)
        config.update(
            {
                "output_hidden_states": True,
                "hidden_dropout_prob": hidden_dropout_prob,
                "layer_norm_eps": layer_norm_eps,
                "add_pooling_layer": False,
                "num_labels": self.num_labels,
            }
        )
        self.transformer = AutoModel.from_pretrained(model_name, config=config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.2)
        self.dropout3 = nn.Dropout(0.3)
        self.dropout4 = nn.Dropout(0.4)
        self.dropout5 = nn.Dropout(0.5)
        self.output = nn.Linear(config.hidden_size, self.num_labels)

    def forward(self, ids, mask, token_type_ids=None):
        transformer_out = self.transformer(ids, mask)
        sequence_output = transformer_out.last_hidden_state
        sequence_output = self.dropout(sequence_output)
        if self.multi_sample_dropout:
            # multi-sample dropout: average the logits over five dropout masks
            logits1 = self.output(self.dropout1(sequence_output))
            logits2 = self.output(self.dropout2(sequence_output))
            logits3 = self.output(self.dropout3(sequence_output))
            logits4 = self.output(self.dropout4(sequence_output))
            logits5 = self.output(self.dropout5(sequence_output))
            logits = (logits1 + logits2 + logits3 + logits4 + logits5) / 5
            logits = torch.softmax(logits, dim=-1)
            return logits
        else:
            return sequence_output

    def configure_optimizers(self):
        param_optimizer = list(self.named_parameters())
        no_decay = ["bias", "LayerNorm.bias"]
        optimizer_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
            },
            {
                "params": [
                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_parameters, lr=self.learning_rate)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(0.1 * self.num_train_steps),
            num_training_steps=self.num_train_steps,
            num_cycles=1,
            last_epoch=-1,
        )
        scheduler = {
            "scheduler": scheduler,
            "interval": self.step_scheduler_after,
            "frequency": 1,
        }
        return [optimizer], [scheduler]

    def _calculate_loss(self, outputs, targets, attention_mask):
        # compute the loss only over active (non-padded) tokens
        loss_fct = nn.CrossEntropyLoss()
        active_loss = attention_mask.view(-1) == 1
        active_logits = outputs.view(-1, self.num_labels)
        true_labels = targets.view(-1)
        idxs = np.where(active_loss.cpu().numpy() == 1)[0]
        active_logits = active_logits[idxs]
        true_labels = true_labels[idxs].to(torch.long)
        loss = loss_fct(active_logits, true_labels)
        return loss

    def training_step(self, batch, batch_idx):
        ids, mask, targets = batch["ids"], batch["mask"], batch["targets"]
        outputs = self(ids, mask)
        loss = self._calculate_loss(outputs, targets, mask)
        return loss

    def validation_step(self, batch, batch_idx):
        ids, mask, targets = batch["ids"], batch["mask"], batch["targets"]
        outputs = self(ids, mask)
        loss = self._calculate_loss(outputs, targets, mask)
        return {"loss": loss, "preds": outputs, "targets": targets}

    def validation_epoch_end(self, validation_step_outputs):
        preds = []
        targets = []
        for output in validation_step_outputs:
            preds += output["preds"]
            targets += output["targets"]
        targets = torch.stack(targets)  # torch.Size([2, 1536])
        preds = torch.stack(preds)  # torch.Size([2, 1536, 15])
        return {"targets": targets, "preds": preds}
```

Custom callback:

```python
class CompMetricEvaluator(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        print("After validation epoch [custom metric evaluation]")
        # calculate custom metric here....
```
-
hey @Gladiator07! You can either override the `on_validation_batch_end` hook and cache the outputs in some variable of the callback, then use that:

```python
class CustomCallback(Callback):
    def __init__(self):
        self.val_outs = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.val_outs.append(outputs)

    def on_validation_epoch_end(self, trainer, pl_module):
        self.val_outs  # <- access them here
```

or cache the val outputs in `pl_module` inside `validation_epoch_end`:

```python
class LitModel(LightningModule):
    def validation_epoch_end(self, outputs):
        new_outputs = ...
        self.val_outs = new_outputs


class CustomCallback(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        pl_module.val_outs  # <- access them here
```

Note that the `trainer` and `pl_module` passed into callbacks are passed by reference, so any change in the original LightningModule will be reflected in the referenced instance here too.
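Putting the first approach together for this use case, here's a minimal sketch of what `CompMetricEvaluator` could look like; the token-accuracy computation at the end is just a hypothetical stand-in for the actual competition metric:

```python
import torch
from pytorch_lightning import Callback


class CompMetricEvaluator(Callback):
    def __init__(self):
        self.val_outs = []

    def on_validation_epoch_start(self, trainer, pl_module):
        # reset the cache so outputs from previous epochs don't accumulate
        self.val_outs = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # `outputs` is the dict returned by validation_step for this batch
        self.val_outs.append(outputs)

    def on_validation_epoch_end(self, trainer, pl_module):
        preds = torch.cat([out["preds"] for out in self.val_outs])      # (N, seq_len, num_labels)
        targets = torch.cat([out["targets"] for out in self.val_outs])  # (N, seq_len)
        # stand-in metric: token-level accuracy over all positions
        accuracy = (preds.argmax(dim=-1) == targets).float().mean()
        print(f"custom metric after validation epoch: {accuracy.item():.4f}")
```

Resetting the list in `on_validation_epoch_start` matters because the same callback instance lives for the whole fit loop, so the cached outputs would otherwise keep growing across epochs.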