Skip to content

Commit 72f66dd

Browse files
committed
Organize array datamodel tests
1 parent 69cb43e commit 72f66dd

File tree

3 files changed

+62
-81
lines changed

3 files changed

+62
-81
lines changed

numba_dpex/tests/core/types/DpnpNdArray/test_models.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

numba_dpex/tests/core/types/USMNdArray/test_models.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# SPDX-FileCopyrightText: 2020 - 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import pytest
6+
from numba import types
7+
from numba.core.datamodel import default_manager, models
8+
from numba.core.registry import cpu_target
9+
10+
from numba_dpex.core.datamodel.models import (
11+
USMArrayDeviceModel,
12+
USMArrayHostModel,
13+
dpex_data_model_manager,
14+
)
15+
from numba_dpex.core.types.dpnp_ndarray_type import DpnpNdArray, USMNdArray
16+
17+
18+
@pytest.fixture(
19+
params=[
20+
array(ndim=ndim, dtype=types.float32, layout="C")
21+
for array in [DpnpNdArray, USMNdArray]
22+
for ndim in range(4)
23+
]
24+
)
25+
def nd_array(request):
26+
return request.param
27+
28+
29+
def test_model_for_array(nd_array):
30+
"""Test the datamodel for DpnpNdArray that is registered with numba's
31+
default datamodel manager and numba_dpex's kernel data model manager.
32+
"""
33+
device_model = dpex_data_model_manager.lookup(nd_array)
34+
assert isinstance(device_model, USMArrayDeviceModel)
35+
host_model = default_manager.lookup(nd_array)
36+
assert isinstance(host_model, USMArrayHostModel)
37+
38+
39+
def test_dpnp_usm_ndarray_model():
40+
"""Test for dpctl.tensor.usm_ndarray model.
41+
42+
It is a subclass of models.StructModel and models.ArrayModel.
43+
"""
44+
45+
assert issubclass(USMArrayHostModel, models.StructModel)
46+
assert issubclass(USMArrayDeviceModel, models.StructModel)
47+
48+
49+
def test_flattened_member_count(nd_array):
50+
"""Test that the number of flattened member count matches the number of
51+
flattened args generated by the CpuTarget's ArgPacker.
52+
"""
53+
54+
cputargetctx = cpu_target.target_context
55+
dpex_dmm = cputargetctx.data_model_manager
56+
57+
argty_tuple = tuple([nd_array])
58+
datamodel = dpex_dmm.lookup(nd_array)
59+
num_flattened_args = datamodel.flattened_field_count
60+
ap = cputargetctx.get_arg_packer(argty_tuple)
61+
62+
assert num_flattened_args == len(ap._be_args)

0 commit comments

Comments
 (0)