Skip to content

Commit 9e489b1

Browse files
Fix Problem of .view(-1) in TensorRecorder Util
1 parent 2277543 commit 9e489b1

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

DeepQuant/Utils/TensorRecorder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ def compareTensors(self) -> Dict[str, Dict]:
9999
def _topDifferences(
100100
self, ref: torch.Tensor, cur: torch.Tensor, diffMask: torch.Tensor
101101
) -> List[str]:
102-
maskFlat = diffMask.view(-1).bool()
102+
maskFlat = diffMask.reshape(-1).bool()
103103
if maskFlat.sum() == 0:
104104
return []
105105

106-
absDiff = (ref - cur).abs().view(-1)[maskFlat]
106+
absDiff = (ref - cur).abs().reshape(-1)[maskFlat]
107107
unique, counts = torch.unique(absDiff, return_counts=True)
108108
order = counts.argsort(descending=True)
109109

@@ -113,8 +113,8 @@ def _topDifferences(
113113
count = counts[idx].item()
114114
sampleIndex = (absDiff == delta).nonzero(as_tuple=False)[0].item()
115115
globalIndex = maskFlat.nonzero(as_tuple=False)[sampleIndex].item()
116-
beforeValue = ref.view(-1)[globalIndex].item()
117-
afterValue = cur.view(-1)[globalIndex].item()
116+
beforeValue = ref.reshape(-1)[globalIndex].item()
117+
afterValue = cur.reshape(-1)[globalIndex].item()
118118

119119
lines.append(
120120
f" · Δ={delta:.6f} ({count} values) e.g. idx {globalIndex}: "

0 commit comments

Comments
 (0)