
I think you have an error. #19

@typhoon1104


DEEP_CORAL_LOSS:
```python
import torch


def CORAL(source, target):
    # source: (ns, d) features, target: (nt, d) features
    d = source.shape[1]
    ns = source.shape[0]
    nt = target.shape[0]

    # source covariance (unbiased: divided by ns - 1)
    xm = torch.mean(source, 0, keepdim=True) - source
    xc = (xm.t() @ xm) / (ns - 1)

    # target covariance (unbiased: divided by nt - 1)
    xmt = torch.mean(target, 0, keepdim=True) - target
    xct = (xmt.t() @ xmt) / (nt - 1)

    # squared Frobenius norm of the covariance difference, scaled by 1/(4*d*d)
    loss = torch.sum(torch.mul((xc - xct), (xc - xct)))
    loss = loss / (4 * d * d)
    return loss
```
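To sanity-check the values, here is a minimal NumPy sketch of the same quantity, $\frac{1}{4d^2}\lVert C_S - C_T\rVert_F^2$ with covariances normalized by $n-1$. The name `coral_loss_np` is hypothetical and not part of this repository; it relies on `np.cov(..., rowvar=False)`, whose default `ddof=1` gives the $n-1$ normalization. Passing the same batches (converted with `.detach().numpy()`) to this function and to `CORAL` should give matching values up to floating-point precision.

```python
import numpy as np


def coral_loss_np(source, target):
    """Hypothetical reference CORAL loss: ||C_S - C_T||_F^2 / (4 d^2).

    source: (ns, d) array, target: (nt, d) array.
    np.cov with rowvar=False treats rows as samples and, with the
    default ddof=1, divides by n - 1.
    """
    d = source.shape[1]
    cs = np.cov(source, rowvar=False)
    ct = np.cov(target, rowvar=False)
    diff = cs - ct
    # sum of squared entries = squared Frobenius norm
    return float(np.sum(diff * diff) / (4 * d * d))


rng = np.random.default_rng(0)
xs = rng.normal(size=(32, 8))
xt = rng.normal(loc=1.0, scale=2.0, size=(24, 8))
print(coral_loss_np(xs, xt))
print(coral_loss_np(xs, xs))  # identical inputs -> 0.0
```

Note that the sign of the centering (`mean - x` vs. `x - mean`) does not matter, because it cancels in the product $X_c^\top X_c$.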
