upd_matrix = torch.linalg.solve(
    P[i,:,:].cuda() @ (layer_ks @ layer_ks.T) + hparams.L2*torch.eye(layer_ks.shape[0], dtype=torch.float,device="cuda"),
    P[i,:,:].cuda() @ layer_ks @ resid
)
This differs from the original AlphaEdit implementation:
https://github.com/jianghoucheng/AlphaEdit/blob/75e117bcbbc58b1fd28df41e1cbb2fa3d25e3ef1/AlphaEdit/AlphaEdit_main.py#L130-L132
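
To make the linear system in the snippet above easier to inspect, here is a minimal sketch that evaluates the same expression on small random CPU tensors. The helper name `solve_update`, the tensor shapes, and the construction of `P` as a random orthogonal projector are assumptions made for illustration only; they are not taken from either implementation.

```python
import torch

def solve_update(P, layer_ks, resid, l2):
    """Solve (P @ K @ K.T + l2 * I) @ X = P @ K @ resid for X.

    Hypothetical helper mirroring the expression in the snippet above.
    Assumed shapes: P (d, d), layer_ks (d, n), resid (n, m).
    """
    d = layer_ks.shape[0]
    eye = torch.eye(d, dtype=layer_ks.dtype, device=layer_ks.device)
    lhs = P @ (layer_ks @ layer_ks.T) + l2 * eye
    rhs = P @ layer_ks @ resid
    return torch.linalg.solve(lhs, rhs), lhs, rhs

torch.manual_seed(0)
d, n, m, r = 8, 5, 3, 2  # illustrative sizes, not from the source

layer_ks = torch.randn(d, n, dtype=torch.float64)
resid = torch.randn(n, m, dtype=torch.float64)

# Assumption: P is an orthogonal projector that removes an r-dimensional subspace.
Q, _ = torch.linalg.qr(torch.randn(d, r, dtype=torch.float64))
P = torch.eye(d, dtype=torch.float64) - Q @ Q.T

upd_matrix, lhs, rhs = solve_update(P, layer_ks, resid, l2=10.0)
print(upd_matrix.shape)
```

Checking `lhs @ upd_matrix` against `rhs` is a quick way to confirm that a given variant of the left-hand side is the system actually being solved.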
