Skip to content

Commit c74a3c1

Browse files
Merge pull request #2762 from devitocodes/dereference-expr
compiler: Enhance Dereference
2 parents ebe4846 + 75c80d8 commit c74a3c1

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

devito/ir/iet/nodes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1057,10 +1057,11 @@ class Dereference(ExprStmt, Node):
10571057

10581058
is_Dereference = True
10591059

1060-
def __init__(self, pointee, pointer, flat=None):
1060+
def __init__(self, pointee, pointer, flat=None, offset=None):
10611061
self.pointee = pointee
10621062
self.pointer = pointer
10631063
self.flat = flat
1064+
self.offset = offset
10641065

10651066
def __repr__(self):
10661067
return "<Dereference(%s,%s)>" % (self.pointee, self.pointer)
@@ -1088,6 +1089,9 @@ def expr_symbols(self):
10881089
else:
10891090
assert False, f"Unexpected pointer type {type(self.pointer)}"
10901091

1092+
if self.offset is not None:
1093+
ret.append(self.offset)
1094+
10911095
return tuple(filter_ordered(ret))
10921096

10931097
@property

devito/ir/iet/visitors.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,12 @@ def visit_PointerCast(self, o):
506506

507507
def visit_Dereference(self, o):
508508
a0, a1 = o.functions
509+
510+
if o.offset:
511+
ptr = f'({a1.name} + {o.offset})'
512+
else:
513+
ptr = a1.name
514+
509515
if a0.is_AbstractFunction:
510516
cstr = self.ccode(a0.indexed._C_typedata)
511517

@@ -517,17 +523,17 @@ def visit_Dereference(self, o):
517523

518524
if o.flat is None:
519525
shape = ''.join(f"[{self.ccode(i)}]" for i in a0.symbolic_shape[1:])
520-
rvalue = f'({cstr} (*){shape}) {a1.name}{cdim}'
526+
rvalue = f'({cstr} (*){shape}) {ptr}{cdim}'
521527
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}')
522528
else:
523-
rvalue = f'({cstr} *) {a1.name}{cdim}'
529+
rvalue = f'({cstr} *) {ptr}{cdim}'
524530
lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}')
525531

526532
else:
527533
if a1.is_Symbol:
528-
rvalue = f'*{a1.name}'
534+
rvalue = f'*{ptr}'
529535
else:
530-
rvalue = f'{a1.name}->{a0._C_name}'
536+
rvalue = f'{ptr}->{a0._C_name}'
531537
lvalue = self._gen_value(a0, 0)
532538

533539
return c.Initializer(lvalue, rvalue)

tests/test_iet.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
switchconfig)
1010
from devito.ir.iet import (
1111
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration, List,
12-
KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations,
13-
make_efunc, retrieve_iteration_tree, Transformer
12+
KernelLaunch, Dereference, Lambda, ElementalFunction, CGen, FindSymbols,
13+
filter_iterations, make_efunc, retrieve_iteration_tree, Transformer
1414
)
1515
from devito.ir import SymbolRegistry
1616
from devito.passes.iet.engine import Graph
1717
from devito.passes.iet.languages.C import CDataManager
1818
from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class,
1919
String, FLOAT)
2020
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
21-
from devito.types import CustomDimension, Array, LocalObject, Symbol
21+
from devito.types import CustomDimension, Array, LocalObject, Symbol, Pointer
2222

2323

2424
@pytest.fixture
@@ -496,3 +496,16 @@ def test_list_inline():
496496

497497
lst = List(body=[expr0, expr1], inline=True)
498498
assert str(lst) == """a = 1; b = 2;"""
499+
500+
501+
def test_dereference_base_plus_off():
502+
ptr = Pointer(name='p', dtype=np.float32)
503+
off = Symbol(name='offs', dtype=np.int32)
504+
505+
dim0 = CustomDimension(name='d0', symbolic_size=2)
506+
dim1 = CustomDimension(name='d1', symbolic_size=3)
507+
x = Array(name='x', dimensions=(dim0, dim1), dtype=np.float32)
508+
509+
deref = Dereference(x, ptr, offset=off)
510+
511+
assert str(deref) == "float (*restrict x)[3] = (float (*)[3]) (p + offs);"

0 commit comments

Comments
 (0)