|
41 | 41 | )
|
42 | 42 |
|
43 | 43 |
|
| 44 | +def _normalize_indices(context, builder, indty, inds, aryty): |
| 45 | + """ |
| 46 | + Convert integer indices into tuple of intp |
| 47 | + """ |
| 48 | + if indty in types.integer_domain: |
| 49 | + indty = types.UniTuple(dtype=indty, count=1) |
| 50 | + indices = [inds] |
| 51 | + else: |
| 52 | + indices = cgutils.unpack_tuple(builder, inds, count=len(indty)) |
| 53 | + indices = [ |
| 54 | + context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices) |
| 55 | + ] |
| 56 | + |
| 57 | + if aryty.ndim != len(indty): |
| 58 | + raise TypeError( |
| 59 | + f"indexing {aryty.ndim}-D array with {len(indty)}-D index" |
| 60 | + ) |
| 61 | + |
| 62 | + return indty, indices |
| 63 | + |
| 64 | + |
44 | 65 | def _parse_enum_or_int_literal_(literal_int) -> int:
|
45 | 66 | """Parse an instance of an enum class or numba.core.types.Literal to its
|
46 | 67 | actual int value.
|
@@ -208,23 +229,22 @@ def _intrinsic_atomic_ref_ctor(
|
208 | 229 | sig = ty_retty(ref, ty_index, ty_retty_ref)
|
209 | 230 |
|
210 | 231 | def codegen(context, builder, sig, args):
|
211 |
| - ref = args[0] |
212 |
| - index_pos = args[1] |
| 232 | + aryty, indty, _ = sig.args |
| 233 | + ary, inds, _ = args |
213 | 234 |
|
214 |
| - dmm = context.data_model_manager |
215 |
| - data_attr_pos = dmm.lookup(sig.args[0]).get_field_position("data") |
216 |
| - data_attr = builder.extract_value(ref, data_attr_pos) |
| 235 | + indty, indices = _normalize_indices( |
| 236 | + context, builder, indty, inds, aryty |
| 237 | + ) |
217 | 238 |
|
218 |
| - with builder.goto_entry_block(): |
219 |
| - ptr_to_data_attr = builder.alloca(data_attr.type) |
220 |
| - builder.store(data_attr, ptr_to_data_attr) |
221 |
| - ref_ptr_value = builder.gep(builder.load(ptr_to_data_attr), [index_pos]) |
| 239 | + lary = context.make_array(aryty)(context, builder, ary) |
| 240 | + ref_ptr_value = cgutils.get_item_pointer( |
| 241 | + context, builder, aryty, lary, indices, wraparound=True |
| 242 | + ) |
222 | 243 |
|
223 | 244 | atomic_ref_struct = cgutils.create_struct_proxy(ty_retty)(
|
224 | 245 | context, builder
|
225 | 246 | )
|
226 |
| - ref_attr_pos = dmm.lookup(ty_retty).get_field_position("ref") |
227 |
| - atomic_ref_struct[ref_attr_pos] = ref_ptr_value |
| 247 | + atomic_ref_struct.ref = ref_ptr_value |
228 | 248 | # pylint: disable=protected-access
|
229 | 249 | return atomic_ref_struct._getvalue()
|
230 | 250 |
|
|
0 commit comments