From 5fae3d96649db4576cad0457b3b1a2b9c2862625 Mon Sep 17 00:00:00 2001 From: hongkonghector <136786589+hongkonghector@users.noreply.github.com> Date: Fri, 16 Jun 2023 17:53:40 +0800 Subject: [PATCH 1/2] Update svd_on_activations.py --- pytorch_grad_cam/utils/svd_on_activations.py | 43 ++++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/pytorch_grad_cam/utils/svd_on_activations.py b/pytorch_grad_cam/utils/svd_on_activations.py index a406aeea8..4a3013b09 100644 --- a/pytorch_grad_cam/utils/svd_on_activations.py +++ b/pytorch_grad_cam/utils/svd_on_activations.py @@ -1,19 +1,28 @@ -import numpy as np +import torch -def get_2d_projection(activation_batch): - # TBD: use pytorch batch svd implementation - activation_batch[np.isnan(activation_batch)] = 0 - projections = [] - for activations in activation_batch: - reshaped_activations = (activations).reshape( - activations.shape[0], -1).transpose() - # Centering before the SVD seems to be important here, - # Otherwise the image returned is negative - reshaped_activations = reshaped_activations - \ - reshaped_activations.mean(axis=0) - U, S, VT = np.linalg.svd(reshaped_activations, full_matrices=True) - projection = reshaped_activations @ VT[0, :] - projection = projection.reshape(activations.shape[1:]) - projections.append(projection) - return np.float32(projections) +def get_2d_projection( batch_activations): + + b, c, h, w = batch_activations.shape + + #x = rearrange(batch_activations, "b c h w -> b (h w) c") + x = batch_activations.reshape( b, c, h * w).permute( 0, 2, 1) + + x_mean = x.mean(1, keepdim=True) + + x = x - x_mean + + U, S, VT = torch.linalg.svd( x ) + + #transpose + + #V = rearrange(VT, 'a b c -> a c b') + V = VT.permute( 0, 2, 1) + V = V[ :, :, 0 : 1 ] + + projection = torch.bmm(x, V).squeeze( ) + + #projection = rearrange( projection, 'b (h w) -> b h w', h = h, w = w) + projection = projection.reshape( b, h, w) + + return projection From 
308da1dc96b57fa87b753304bc8f0fa6fc88419e Mon Sep 17 00:00:00 2001
From: hongkonghector <136786589+hongkonghector@users.noreply.github.com>
Date: Fri, 16 Jun 2023 18:50:49 +0800
Subject: [PATCH 2/2] Update svd_on_activations.py

Rewrote the code to use the torch batch SVD. The code still assumes that
the input and output must be numpy arrays, so they are converted to
torch tensors and then back to numpy arrays. `import torch` is kept,
since the module still calls torch.from_numpy, torch.linalg.svd and
torch.bmm.

NOTE(review): the previous numpy implementation zeroed NaNs in the input
(activation_batch[np.isnan(activation_batch)] = 0) before the SVD; that
guard is not reproduced by this rewrite — confirm whether callers can
pass NaNs, and restore it if so.
---
 pytorch_grad_cam/utils/svd_on_activations.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch_grad_cam/utils/svd_on_activations.py b/pytorch_grad_cam/utils/svd_on_activations.py
index 4a3013b09..64df76e15 100644
--- a/pytorch_grad_cam/utils/svd_on_activations.py
+++ b/pytorch_grad_cam/utils/svd_on_activations.py
@@ -1,8 +1,10 @@
 import torch
 
 def get_2d_projection( batch_activations):
 
+    batch_activations = torch.from_numpy( batch_activations)
+
     b, c, h, w = batch_activations.shape
 
     #x = rearrange(batch_activations, "b c h w -> b (h w) c")
 
@@ -20,9 +22,9 @@ def get_2d_projection( batch_activations):
     V = VT.permute( 0, 2, 1)
     V = V[ :, :, 0 : 1 ]
 
-    projection = torch.bmm(x, V).squeeze( )
+    projection = torch.bmm(x, V).squeeze( -1 )
 
     #projection = rearrange( projection, 'b (h w) -> b h w', h = h, w = w)
     projection = projection.reshape( b, h, w)
 
-    return projection
+    return projection.detach().numpy( )