import numpy as np
import torch


def get_2d_projection(batch_activations) -> np.ndarray:
    """Project each activation map onto its first principal component.

    Every image in the batch is treated as ``h * w`` samples (one per
    spatial position) with ``c`` features (one per channel).  The samples
    are centered, the dominant right-singular vector of the centered
    matrix is computed with a batched SVD, and each spatial position is
    projected onto that direction.

    Args:
        batch_activations: Array-like of shape ``(b, c, h, w)`` (NumPy
            array or torch tensor).  NaN entries are treated as zero.
            The input itself is not modified.

    Returns:
        ``np.ndarray`` of dtype ``float32`` and shape ``(b, h, w)``.  As
        with any SVD, the overall sign of each projection is determined
        by the underlying solver.

    Raises:
        ValueError: If the input is not 4-dimensional.
    """
    x = torch.as_tensor(batch_activations).detach()
    if x.dim() != 4:
        raise ValueError(
            "batch_activations must have shape (b, c, h, w), "
            f"got {tuple(x.shape)}"
        )
    # torch.linalg.svd only supports single/double precision floats.
    if x.dtype not in (torch.float32, torch.float64):
        x = x.to(torch.float32)

    # A single NaN would make the SVD fail or poison the whole projection
    # of that image, so zero them out first (out-of-place: the caller's
    # buffer, which as_tensor may share, is left untouched).
    x = x.masked_fill(torch.isnan(x), 0.0)

    b, c, h, w = x.shape
    # (b, c, h, w) -> (b, h*w, c): rows are spatial positions.
    x = x.reshape(b, c, h * w).permute(0, 2, 1)

    # Centering before the SVD is important; without it the first
    # component mostly captures the mean and the map tends to come out
    # negative.
    x = x - x.mean(dim=1, keepdim=True)

    # Only the first right-singular vector is needed, so use the reduced
    # SVD and avoid materialising an (h*w x h*w) U matrix per image.
    _, _, vh = torch.linalg.svd(x, full_matrices=False)
    first_component = vh[:, 0, :].unsqueeze(-1)  # (b, c, 1)

    projection = torch.bmm(x, first_component).squeeze(-1)  # (b, h*w)
    projection = projection.reshape(b, h, w)

    return projection.to(torch.float32).cpu().numpy()