Skip to content

Commit 6c80f8b

Browse files
aobolensklantigat-viriccardofelluga
authored
Add tests for NumPy language context and fix import path (#2690)
Co-authored-by: Luca Antiga <luca@lightning.ai> Co-authored-by: Thomas Viehmann <tv.code@beamnet.de> Co-authored-by: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com>
1 parent ecc3a79 commit 6c80f8b

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

thunder/numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from numbers import Number
22
from collections.abc import Callable
33

4-
from thunder.core.langctx import langctx, Languages
4+
from thunder.core.langctxs import langctx, Languages
55
from thunder.numpy.langctx import register_method
66

77
from thunder.core.proxies import TensorProxy
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from thunder.numpy import size as np_size
2+
from thunder.core.langctxs import langctx, Languages, resolve_language
3+
from thunder.core.proxies import TensorProxy
4+
from thunder.core.trace import detached_trace
5+
from thunder.core.devices import cpu
6+
from thunder.core.dtypes import float32
7+
8+
9+
def test_numpy_langctx_registration_and_len_size():
10+
with detached_trace():
11+
t = TensorProxy(shape=(2, 3), device=cpu, dtype=float32)
12+
13+
with langctx(Languages.NUMPY):
14+
assert len(t) == 2 # axis 0 length
15+
assert t.size() == 6 # total elements
16+
assert np_size(t) == 6
17+
18+
19+
def test_numpy_langctx_resolve_language():
20+
numpy_ctx_by_enum = resolve_language(Languages.NUMPY)
21+
numpy_ctx_by_name = resolve_language("numpy")
22+
23+
assert numpy_ctx_by_enum is numpy_ctx_by_name
24+
assert numpy_ctx_by_enum.name == "numpy"

0 commit comments

Comments
 (0)