Skip to content

Commit 84de964

Browse files
author
Diptorup Deb
committed
Unit tests for range/ndrange data models.
1 parent 57c0d51 commit 84de964

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import pytest
6+
from numba.core.datamodel import default_manager
7+
8+
from numba_dpex.core.datamodel.models import (
9+
NdRangeModel,
10+
RangeModel,
11+
dpex_data_model_manager,
12+
)
13+
from numba_dpex.core.types.range_types import NdRangeType, RangeType
14+
15+
rfields = ["ndim", "dim0", "dim1", "dim2"]
16+
ndrfields = ["ndim", "gdim0", "gdim1", "gdim2", "ldim0", "ldim1", "ldim2"]
17+
18+
19+
def test_datamodel_registration():
20+
"""Test the datamodel for RangeType and NdRangeType are found in numba's
21+
default datamodel manager but not in numba_dpex's kernel data model manager.
22+
"""
23+
range_ty = RangeType(ndim=1)
24+
ndrange_ty = NdRangeType(ndim=1)
25+
26+
with pytest.raises(KeyError):
27+
dpex_data_model_manager.lookup(range_ty)
28+
dpex_data_model_manager.lookup(ndrange_ty)
29+
30+
default_range_model = default_manager.lookup(range_ty)
31+
default_ndrange_model = default_manager.lookup(ndrange_ty)
32+
33+
assert isinstance(default_range_model, RangeModel)
34+
assert isinstance(default_ndrange_model, NdRangeModel)
35+
36+
37+
@pytest.mark.parametrize("field", rfields)
38+
def test_range_model_fields(field):
39+
"""Tests that the expected fields are found in the data model for
40+
RangeType
41+
"""
42+
range_ty = RangeType(ndim=1)
43+
dm = default_manager.lookup(range_ty)
44+
try:
45+
dm.get_field_position(field)
46+
except:
47+
pytest.fail(f"Expected field {field} not present in RangeModel")
48+
49+
50+
@pytest.mark.parametrize("field", ndrfields)
51+
def test_ndrange_model_fields(field):
52+
"""Tests that the expected fields are found in the data model for
53+
NdRangeType
54+
"""
55+
ndrange_ty = NdRangeType(ndim=1)
56+
dm = default_manager.lookup(ndrange_ty)
57+
try:
58+
dm.get_field_position(field)
59+
except:
60+
pytest.fail(f"Expected field {field} not present in NdRangeModel")

0 commit comments

Comments
 (0)