Skip to content

Commit 08a5a82

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
scalar metric fix window count state size (#3302)
Summary: Pull Request resolved: #3302 make window_count size same as other state tensors. no issues with checkpointing as these states are not saved to state dict (persistent=False) Reviewed By: irobert0126 Differential Revision: D80275353 fbshipit-source-id: 0176fc551a9629480d416eaf4abb896447fa13b2
1 parent 7f6e7fb commit 08a5a82

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchrec/metrics/scalar.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,10 @@ def update(
5353

5454
states = {
5555
"labels": labels.mean(dim=-1),
56-
"window_count": torch.tensor([1.0]).to(
57-
labels.device
56+
"window_count": torch.ones(
57+
torch.Size([self._n_tasks]),
58+
device=labels.device,
59+
dtype=torch.double,
5860
), # put window count on the correct device
5961
}
6062
for state_name, state_value in states.items():

0 commit comments

Comments
 (0)