Skip to content

Commit 6e74aea

Browse files
author
Diptorup Deb
committed
Add flattened_field_count property to RangeModel and NdRangeModel.
1 parent 1dda097 commit 6e74aea

File tree

1 file changed

+36
-40
lines changed

1 file changed

+36
-40
lines changed

numba_dpex/core/datamodel/models.py

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,30 @@
2020
)
2121

2222

23+
def _get_flattened_member_count(ty):
24+
"""Return the number of fields in an instance of a given StructModel."""
25+
flattened_member_count = 0
26+
members = ty._members
27+
for member in members:
28+
if isinstance(member, types.UniTuple):
29+
flattened_member_count += member.count
30+
elif isinstance(
31+
member,
32+
(
33+
types.scalars.Integer,
34+
types.misc.PyObject,
35+
types.misc.RawPointer,
36+
types.misc.CPointer,
37+
types.misc.MemInfoPointer,
38+
),
39+
):
40+
flattened_member_count += 1
41+
else:
42+
raise UnreachableError
43+
44+
return flattened_member_count
45+
46+
2347
class GenericPointerModel(PrimitiveModel):
2448
def __init__(self, dmm, fe_type):
2549
adrsp = (
@@ -68,26 +92,7 @@ def __init__(self, dmm, fe_type):
6892
@property
6993
def flattened_field_count(self):
7094
"""Return the number of fields in an instance of a USMArrayModel."""
71-
flattened_member_count = 0
72-
members = self._members
73-
for member in members:
74-
if isinstance(member, types.UniTuple):
75-
flattened_member_count += member.count
76-
elif isinstance(
77-
member,
78-
(
79-
types.scalars.Integer,
80-
types.misc.PyObject,
81-
types.misc.RawPointer,
82-
types.misc.CPointer,
83-
types.misc.MemInfoPointer,
84-
),
85-
):
86-
flattened_member_count += 1
87-
else:
88-
raise UnreachableError
89-
90-
return flattened_member_count
95+
return _get_flattened_member_count(self)
9196

9297

9398
class DpnpNdArrayModel(StructModel):
@@ -121,26 +126,7 @@ def __init__(self, dmm, fe_type):
121126
@property
122127
def flattened_field_count(self):
123128
"""Return the number of fields in an instance of a DpnpNdArrayModel."""
124-
flattened_member_count = 0
125-
members = self._members
126-
for member in members:
127-
if isinstance(member, types.UniTuple):
128-
flattened_member_count += member.count
129-
elif isinstance(
130-
member,
131-
(
132-
types.scalars.Integer,
133-
types.misc.PyObject,
134-
types.misc.RawPointer,
135-
types.misc.CPointer,
136-
types.misc.MemInfoPointer,
137-
),
138-
):
139-
flattened_member_count += 1
140-
else:
141-
raise UnreachableError
142-
143-
return flattened_member_count
129+
return _get_flattened_member_count(self)
144130

145131

146132
class SyclQueueModel(StructModel):
@@ -211,6 +197,11 @@ def __init__(self, dmm, fe_type):
211197
]
212198
super(RangeModel, self).__init__(dmm, fe_type, members)
213199

200+
@property
201+
def flattened_field_count(self):
202+
"""Return the number of fields in an instance of a RangeModel."""
203+
return _get_flattened_member_count(self)
204+
214205

215206
class NdRangeModel(StructModel):
216207
"""The native data model for a
@@ -229,6 +220,11 @@ def __init__(self, dmm, fe_type):
229220
]
230221
super(NdRangeModel, self).__init__(dmm, fe_type, members)
231222

223+
@property
224+
def flattened_field_count(self):
225+
"""Return the number of fields in an instance of a NdRangeModel."""
226+
return _get_flattened_member_count(self)
227+
232228

233229
def _init_data_model_manager() -> datamodel.DataModelManager:
234230
"""Initializes a DpexKernelTarget-specific data model manager.

0 commit comments

Comments
 (0)