diff --git a/easy_tpp/model/torch_model/torch_fullynn.py b/easy_tpp/model/torch_model/torch_fullynn.py index 11af6da..52ae351 100644 --- a/easy_tpp/model/torch_model/torch_fullynn.py +++ b/easy_tpp/model/torch_model/torch_fullynn.py @@ -64,13 +64,13 @@ def forward(self, hidden_states, time_delta_seqs): derivative_integral_lambdas = [] for i in range(integral_lambda.shape[-1]): # iterate over marks derivative_integral_lambdas.append(grad( - integral_lambda[..., i].mean(), + integral_lambda[..., i].sum(), time_delta_seqs, create_graph=True, retain_graph=True)[0]) derivative_integral_lambda = torch.stack(derivative_integral_lambdas, dim=-1) # TODO: Check that it is okay to iterate over marks like this else: - derivative_integral_lambda = grad( - integral_lambda.sum(dim=-1).mean(), + derivative_integral_lambda = grad( + integral_lambda.sum(), time_delta_seqs, create_graph=True, retain_graph=True)[0] derivative_integral_lambda = derivative_integral_lambda.unsqueeze(-1).expand(*derivative_integral_lambda.shape, self.num_event_types) / self.num_event_types