Problem about the loss #3

@FeiiYin

Hi, Yang:

I've read your Leba work and am very interested in it. I am currently trying to reproduce the part that uses the forward loss and the backward loss to narrow the gap between the two models, but I am confused by the code. I hope you can help me. Thank you very much.

Following the code in the repository, I compute the backward loss as follows:

# Fresh leaf tensor that tracks gradients (required by torch.autograd.grad below)
images = images.detach().clone().requires_grad_(True)
adv_images = images + diff

surrogate_logits = surrogate_model(images)
surrogate_loss = nn.CrossEntropyLoss(reduction='none')(surrogate_logits, labels)

# Input gradient of the surrogate loss; create_graph=True keeps it differentiable
grad = torch.autograd.grad(surrogate_loss.sum(), images, create_graph=True)[0]
s_loss = (diff.detach() * grad).view([images.shape[0], -1]).sum(dim=1)  # one value per sample

target_adv_logits = target_model(adv_images)
target_adv_loss = nn.CrossEntropyLoss(reduction='none')(target_adv_logits, labels)

target_ori_logits = target_model(images)
target_ori_loss = nn.CrossEntropyLoss(reduction='none')(target_ori_logits, labels)
d_loss = torch.log(target_adv_loss) - torch.log(target_ori_loss)  # one value per sample

backward_loss = nn.MSELoss()(s_loss / lamda, d_loss.detach())
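
To make the computation above easier to check in isolation, here is a minimal self-contained sketch. The two tiny linear classifiers, the random batch, the random `diff`, and `lamda = 1.0` are all hypothetical stand-ins chosen for illustration; they are not taken from the Leba repository. The sketch only verifies that the backward loss is differentiable with respect to the surrogate weights and that nothing flows into the target model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks: two tiny linear classifiers.
surrogate_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
target_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

images = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))
diff = 0.01 * torch.randn_like(images)  # perturbation, assumed to be given
lamda = 1.0                             # scaling factor, value assumed

ce = nn.CrossEntropyLoss(reduction='none')

# images must require grad, otherwise torch.autograd.grad raises an error.
images = images.detach().clone().requires_grad_(True)

surrogate_loss = ce(surrogate_model(images), labels)
# create_graph=True keeps the input gradient differentiable w.r.t. the weights.
grad = torch.autograd.grad(surrogate_loss.sum(), images, create_graph=True)[0]
s_loss = (diff.detach() * grad).view(images.shape[0], -1).sum(dim=1)  # shape [4]

# The target model is only queried here, so no graph is built for it.
with torch.no_grad():
    target_adv_loss = ce(target_model(images + diff), labels)
    target_ori_loss = ce(target_model(images), labels)
d_loss = torch.log(target_adv_loss) - torch.log(target_ori_loss)  # shape [4]

backward_loss = nn.MSELoss()(s_loss / lamda, d_loss)
backward_loss.backward()  # second-order backward through the surrogate only
```

After `backward()`, the surrogate parameters carry gradients while the target parameters do not, which is the behaviour I expect from the backward loss.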

However, with the above implementation, the gap between the outputs of the surrogate model and the target model keeps widening instead of shrinking. I measure the gap with torch.nn.MSELoss(reduce=True, size_average=False)(target_model(images), surrogate_model(images)).
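
For clarity, the gap metric can also be written with the current `reduction` argument, since `reduce` and `size_average` are deprecated in recent PyTorch releases and `reduce=True, size_average=False` corresponds to `reduction='sum'`. The sketch below uses two hypothetical linear models and a random batch purely as placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks and data.
target_model = nn.Linear(16, 10)
surrogate_model = nn.Linear(16, 10)
images = torch.rand(4, 16)

with torch.no_grad():
    # Sum of squared differences over every logit of every sample.
    gap = nn.MSELoss(reduction='sum')(surrogate_model(images), target_model(images))
    # The same quantity written out by hand.
    manual = ((surrogate_model(images) - target_model(images)) ** 2).sum()
```

Note that this metric compares the full logit vectors, whereas the losses in this post compare only loss values or the softmax probability of the true class, so the two quantities are not guaranteed to move together.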

What's more, following the code in the repository, I compute the forward loss as follows:

surrogate_logits = surrogate_model(images)
surrogate_prob = torch.nn.functional.softmax(surrogate_logits, dim=1)
s_score = surrogate_prob.gather(1, labels.reshape([-1, 1]))

target_logits = target_model(images)
target_prob = torch.nn.functional.softmax(target_logits, dim=1)
target_score = target_prob.gather(1, labels.reshape([-1, 1]))

forward_loss = nn.MSELoss()(s_score, target_score.detach())
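
As a sanity check on the forward loss above, the following minimal sketch runs it on two hypothetical linear models with a random batch (none of these placeholders come from the Leba repository) and confirms that gradients reach only the surrogate model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks and data.
surrogate_model = nn.Linear(16, 10)
target_model = nn.Linear(16, 10)
images = torch.rand(4, 16)
labels = torch.randint(0, 10, (4,))

# Softmax probability that each model assigns to the true class, shape [4, 1].
s_score = F.softmax(surrogate_model(images), dim=1).gather(1, labels.reshape(-1, 1))
with torch.no_grad():
    target_score = F.softmax(target_model(images), dim=1).gather(1, labels.reshape(-1, 1))

forward_loss = nn.MSELoss()(s_score, target_score)
forward_loss.backward()  # gradients flow into the surrogate model only
```

Because only the true-class probability is matched, the remaining logits of the surrogate model are not directly constrained by this loss.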

With the forward loss, the gap does not decrease monotonically either.
Could you tell me which part of my understanding is wrong? :(

Yours.

Fei
