Skip to content

Commit 8757f5f

Browse files
ArmavicaricardoV94
authored andcommitted
Use a Counter in tests/link/test_vm
1 parent cc1a1cb commit 8757f5f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/link/test_vm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
from collections import Counter
23

34
import numpy as np
45
import pytest
@@ -34,11 +35,10 @@ class TestCallbacks:
3435
# Test the `VMLinker`'s callback argument, which can be useful for debugging.
3536

3637
def setup_method(self):
37-
self.n_callbacks = {}
38+
self.n_callbacks = Counter()
3839

3940
def callback(self, node, thunk, storage_map, compute_map):
4041
key = node.op.__class__.__name__
41-
self.n_callbacks.setdefault(key, 0)
4242
self.n_callbacks[key] += 1
4343

4444
def test_callback(self):
@@ -50,9 +50,9 @@ def test_callback(self):
5050
)
5151

5252
f(1, 2, 3)
53-
assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort())
53+
assert self.n_callbacks.total() == len(f.maker.fgraph.toposort())
5454
f(1, 2, 3)
55-
assert sum(self.n_callbacks.values()) == len(f.maker.fgraph.toposort()) * 2
55+
assert self.n_callbacks.total() == len(f.maker.fgraph.toposort()) * 2
5656

5757
def test_callback_with_ifelse(self):
5858
a, b, c = scalars("abc")

0 commit comments

Comments
 (0)