File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change @@ -75,6 +75,7 @@ def get_cam_weights(self,
75
75
retain_graph = False ,
76
76
allow_unused = True
77
77
)[0 ]
78
+ # print(torch.max(hvp[0]).item()) # verify that hvp is not all zeros
78
79
if hvp is None :
79
80
hvp = torch .tensor (0 ).to (self .device )
80
81
elif self .activations_and_grads .reshape_transform is not None :
@@ -83,7 +84,7 @@ def get_cam_weights(self,
83
84
if self .activations_and_grads .reshape_transform is not None :
84
85
activations = self .activations_and_grads .reshape_transform (activations )
85
86
grads = self .activations_and_grads .reshape_transform (grads )
86
- weight = (grads - 0.5 * hvp ).cpu ().detach ().numpy ()
87
+ weight = (grads - 0.5 * hvp ).cpu ().detach ().numpy ()
87
88
activations = activations .cpu ().detach ().numpy ()
88
89
grads = grads .cpu ().detach ().numpy ()
89
90
You can’t perform that action at this time.
0 commit comments