Cream/AutoFormer/model/module/qkv_super.py
Lines 72 to 77 in 4a13c40
```python
def sample_weight(weight, sample_in_dim, sample_out_dim):
    sample_weight = weight[:, :sample_in_dim]
    sample_weight = torch.cat([sample_weight[i:sample_out_dim:3, :] for i in range(3)], dim =0)
    return sample_weight
```
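To see which rows of the super weight the quoted slicing keeps, here is a minimal sketch that tracks row indices with plain Python lists instead of real tensors. List slicing follows the same `start:stop:step` rules as the tensor slicing above. The helper name `original_rows` and the sizes are hypothetical: a toy head dimension of 1, so a 5-head super layer has 15 qkv rows and a 4-head subnet samples 12. The sketch also assumes the sampled output is split into contiguous thirds for q, k and v.

```python
def original_rows(super_out_dim, sample_out_dim):
    """Hypothetical helper: indices of the super-weight rows kept by
    torch.cat([w[i:sample_out_dim:3] for i in range(3)]), in output order."""
    rows = list(range(super_out_dim))  # stand-in for the rows of `weight`
    return [r for i in range(3) for r in rows[i:sample_out_dim:3]]


# Toy sizes (assumed): head_dim = 1, 5 heads in the super layer, 4 heads sampled.
selected = original_rows(15, 12)
# Assumption: q, k, v are contiguous thirds of the sampled output.
q, k, v = selected[0:4], selected[4:8], selected[8:12]
print(q, k, v)  # [0, 3, 6, 9] [1, 4, 7, 10] [2, 5, 8, 11]

# The same rule applied at full width (all 5 heads), for comparison.
print(original_rows(15, 15))  # [0, 3, 6, 9, 12, 1, 4, 7, 10, 13, 2, 5, 8, 11, 14]
```

Under these assumptions, the strided rule picks every third row, so the rows the 4-head subnet uses for q are the first four of the rows the same rule assigns to q at full width (likewise for k and v).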
I think there is something wrong with how weight sharing is done here. I believe this code should be:
```python
N = weight.size(0) // 3  # rows per q/k/v block in the full weight
sample_weight = torch.cat([sample_weight[i*N:i*N+sample_out_dim//3, :] for i in range(3)], dim=0)
```

To make this more intuitive, I drew a schematic diagram showing how self-attention with 4 heads and with 5 heads shares `Linear.weight`.
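The proposed slicing can be traced the same way. This sketch again tracks row indices with plain lists rather than tensors. The helper `proposed_rows` and the sizes (head dimension 1, 5 heads in the super layer, 4 sampled) are hypothetical, and it assumes `weight` stores q, k and v as three contiguous blocks of `N` rows each.

```python
def proposed_rows(super_out_dim, sample_out_dim):
    """Hypothetical helper: indices of the super-weight rows kept by
    torch.cat([w[i*N:i*N + sample_out_dim//3] for i in range(3)]), N = super_out_dim // 3."""
    rows = list(range(super_out_dim))  # stand-in for the rows of `weight`
    n = super_out_dim // 3  # rows per q/k/v block in the super weight
    return [r for i in range(3) for r in rows[i * n:i * n + sample_out_dim // 3]]


# Toy sizes (assumed): head_dim = 1, 5 heads in the super layer, 4 heads sampled.
print(proposed_rows(15, 12))  # [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13]

# At full width the rows come back in their original order.
print(proposed_rows(15, 15) == list(range(15)))  # True
```

Here each sampled third is a prefix of the corresponding contiguous block of the super weight (rows 0 to 4, 5 to 9, and 10 to 14).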
Maybe I have misunderstood the implementation here. Could you help check it?
