
Commit 3551c2a

WIP Dense level
1 parent 48b9353 commit 3551c2a

10 files changed: +226 -33 lines changed


src/finchlite/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -108,6 +108,9 @@
     DenseLevelFType,
     ElementLevelFType,
     FiberTensorFType,
+    dense,
+    element,
+    fiber_tensor,
 )

 __all__ = [
@@ -158,11 +161,14 @@
     "cos",
     "cosh",
     "defer",
+    "dense",
     "dimension",
+    "element",
     "element_type",
     "elementwise",
     "equal",
     "expand_dims",
+    "fiber_tensor",
     "fill_value",
     "fisinstance",
     "flatten",

src/finchlite/codegen/numpy_buffer.py

Lines changed: 4 additions & 0 deletions

@@ -76,6 +76,10 @@ def __str__(self):
         arr_str = str(self.arr).replace("\n", "")
         return f"np_buf({arr_str})"

+    def __repr__(self):
+        arr_repr = repr(self.arr).replace("\n", "")
+        return f"NumpyBuffer({arr_repr})"
+

 class NumpyBufferFType(CBufferFType, NumbaBufferFType, CStackFType):
     """

src/finchlite/compile/bufferized_ndarray.py

Lines changed: 1 addition & 1 deletion

@@ -377,7 +377,7 @@ def asm_repack(self, ctx, lhs, obj):
         """
         Repack the buffer from C context.
         """
-        (self.tns.asm_repack(ctx, lhs.tns, obj.tns),)
+        self.tns.asm_repack(ctx, lhs.tns, obj.tns)
         ctx.exec(
             asm.Block(
                 asm.SetAttr(lhs, "tns", obj.tns),
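The removed line wrapped the repack call in a one-element tuple, so the call still executed but the statement built and immediately discarded a tuple (most likely a stray trailing comma). A minimal illustration of why both forms have the same effect apart from the dead tuple:

    def repack():
        print("repacked")

    (repack(),)  # calls repack(), then constructs and discards a 1-tuple
    repack()     # equivalent effect, without the dead tuple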

src/finchlite/interface/lazy.py

Lines changed: 7 additions & 7 deletions

@@ -332,24 +332,24 @@ def __ne__(self, other):
 register_property(LazyTensor, "asarray", "__attr__", lambda x: x)


-def asarray(arg: Any, format="bufferized") -> Any:
+def asarray(arg: Any, format=None) -> Any:
     """
     Convert given argument and return wrapper type instance.
     If input argument is already array type, return unchanged.

     Args:
         arg: The object to be converted.
+        format: The format for the result array.

     Returns:
         The array type result of the given object.
     """
-    if format != "bufferized":
-        raise Exception(f"Only bufferized format is now supported, got: {format}")
+    if format is None:
+        if hasattr(arg, "asarray"):
+            return arg.asarray()
+        return query_property(arg, "asarray", "__attr__")

-    if hasattr(arg, "asarray"):
-        return arg.asarray()
-
-    return query_property(arg, "asarray", "__attr__")
+    return format(arg)


 def defer(arr) -> LazyTensor:
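With this change, `asarray`'s `format` parameter defaults to `None`, meaning "use the argument's own conversion hook", and any non-None value is treated as a callable applied to the argument (for example a tensor ftype). A sketch of the new calling convention; `obj` and `fmt` are hypothetical placeholders for a convertible object and a format callable:

    from finchlite import asarray

    arr = asarray(obj)              # obj.asarray(), or the registered "asarray" property
    arr = asarray(obj, format=fmt)  # fmt(obj); fmt could be e.g. a FiberTensorFType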

src/finchlite/tensor/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, tensor
+from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, fiber_tensor
 from .level import (
     DenseLevel,
     DenseLevelFType,
@@ -19,5 +19,5 @@
     "LevelFType",
     "dense",
     "element",
-    "tensor",
+    "fiber_tensor",
 ]
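Downstream call sites migrate one-for-one from `tensor` to `fiber_tensor`; a before/after sketch, where `element(0.0)` again stands in for the element level's constructor:

    # before
    fmt = tensor(dense(element(0.0)))
    # after
    fmt = fiber_tensor(dense(element(0.0)))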

src/finchlite/tensor/fiber_tensor.py

Lines changed: 88 additions & 12 deletions

@@ -1,9 +1,15 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Generic, TypeVar
+from typing import Generic, NamedTuple, TypeVar

-from finchlite.algebra import Tensor, TensorFType
-from finchlite.symbolic import FType, FTyped
+import numpy as np
+
+from .. import finch_assembly as asm
+from .. import finch_notation as ntn
+from ..algebra import Tensor, register_property
+from ..codegen.numpy_buffer import NumpyBuffer
+from ..compile.lower import FinchTensorFType
+from ..symbolic import FType, FTyped


 class LevelFType(FType, ABC):
@@ -68,7 +74,7 @@ class Level(FTyped, ABC):

     @property
     @abstractmethod
-    def shape(self):
+    def shape(self) -> tuple:
         """
         Shape of the fibers in the structure.
         """
@@ -161,8 +167,14 @@ def buffer_factory(self):
         return self.lvl.buffer_factory


+class FiberTensorFields(NamedTuple):
+    stride: tuple[asm.Variable, ...]
+    buf: asm.Variable
+    buf_s: asm.Slot
+
+
 @dataclass(unsafe_hash=True)
-class FiberTensorFType(TensorFType):
+class FiberTensorFType(FinchTensorFType):
     """
     An abstract base class representing the ftype of a fiber tensor.

@@ -177,11 +189,17 @@ def __post_init__(self):
         if self._position_type is None:
             self._position_type = self.lvl.position_type

-    def __call__(self, shape):
+    def __call__(self, *, shape=None, val=None):
         """
         Creates an instance of a FiberTensor with the given arguments.
         """
-        return FiberTensor(self.lvl(shape), self.lvl.position_type(1))
+        if shape is None:
+            shape = val.shape
+            val = NumpyBuffer(val.reshape[-1])
+        return FiberTensor(self.lvl(shape, val), self.lvl.position_type(1))
+
+    def __str__(self):
+        return f"FiberTensorFType({self.lvl})"

     @property
     def shape(self):
@@ -209,24 +227,82 @@ def position_type(self):

     @property
     def buffer_factory(self):
+        return self.lvl.buffer_factory
+
+    def unfurl(self, ctx, tns, ext, mode, proto):
+        op = None
+        if isinstance(mode, ntn.Update):
+            op = mode.op
+        tns = ctx.resolve(tns).obj
+        obj = self.lvl.get_fields_class(tns, 0, asm.Literal(self.position_type(0)), op)
+        return self.lvl.unfurl(ctx, ntn.Stack(obj, self.lvl), ext, mode, proto)
+
+    def lower_freeze(self, ctx, tns, op):
+        return tns
+
+    def lower_thaw(self, ctx, tns, op):
+        raise NotImplementedError
+
+    def lower_unwrap(self, ctx, obj):
+        raise NotImplementedError
+
+    def lower_increment(self, ctx, obj, val):
+        raise NotImplementedError
+
+    def lower_declare(self, ctx, tns, init, op, shape):
+        i_var = asm.Variable("i", self.buffer_factory.length_type)
+        body = asm.Store(
+            tns.obj.buf_s,
+            i_var,
+            asm.Literal(init.val),
+        )
+        ctx.exec(
+            asm.ForLoop(i_var, asm.Literal(np.intp(0)), asm.Length(tns.obj.buf_s), body)
+        )
+        return
+
+    def asm_unpack(self, ctx, var_n, val):
         """
-        Returns the ftype of the buffer used for the fibers.
-        This is typically a NumpyBufferFType or similar.
+        Unpack the tensor into the asm context.
         """
-        return self.lvl.buffer_factory
+        stride = []
+        shape_type = self.shape_type
+        for i in range(self.ndim):
+            stride_i = asm.Variable(f"{var_n}_stride_{i}", shape_type[i])
+            stride.append(stride_i)
+            stride_e = asm.GetAttr(val, asm.Literal("strides"))
+            stride_i_e = asm.GetAttr(stride_e, asm.Literal(f"element_{i}"))
+            ctx.exec(asm.Assign(stride_i, stride_i_e))
+        buf = asm.Variable(f"{var_n}_buf", self.buffer_factory)
+        buf_e = asm.GetAttr(val, asm.Literal("buf"))
+        ctx.exec(asm.Assign(buf, buf_e))
+        buf_s = asm.Slot(f"{var_n}_buf_slot", self.buffer_factory)
+        ctx.exec(asm.Unpack(buf_s, buf))
+
+        return FiberTensorFields(tuple(stride), buf, buf_s)
+
+    def asm_repack(self, ctx, lhs, obj):
+        """
+        Repack the buffer from the context.
+        """
+        ctx.exec(asm.Repack(obj.buf_s))
+        return


-def tensor(lvl: LevelFType, position_type: type | None = None):
+def fiber_tensor(lvl: LevelFType, position_type: type | None = None):
     """
     Creates a FiberTensorFType with the given level ftype and position type.

     Args:
         lvl: The level ftype to be used for the tensor.
-        pos_type: The type of positions within the tensor. Defaults to None.
+        position_type: The type of positions within the tensor. Defaults to None.

     Returns:
         An instance of FiberTensorFType.
     """
     # mypy does not understand that dataclasses generate __hash__ and __eq__
     # https://github.com/python/mypy/issues/19799
     return FiberTensorFType(lvl, position_type)  # type: ignore[abstract]
+
+
+register_property(FiberTensor, "asarray", "__attr__", lambda x: x)
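`FiberTensorFType.__call__` is now keyword-only over `shape` and `val`, and the trailing `register_property` call gives `FiberTensor` the same identity `asarray` hook that `LazyTensor` registers in lazy.py. A construction sketch, again assuming an `element(0.0)` constructor and exercising only the `shape` path (the `val` path is still WIP in this commit):

    import finchlite

    fmt = finchlite.fiber_tensor(
        finchlite.dense(finchlite.dense(finchlite.element(0.0)))
    )
    t = fmt(shape=(3, 4))  # keyword-only: fmt((3, 4)) would now raise TypeError
    finchlite.asarray(t)   # resolves through the registered "asarray" property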

src/finchlite/tensor/level/dense_level.py

Lines changed: 64 additions & 5 deletions

@@ -1,32 +1,48 @@
+import operator
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, NamedTuple

 import numpy as np

-from ..fiber_tensor import Level, LevelFType
+from ... import finch_assembly as asm
+from ... import finch_notation as ntn
+from ...compile import looplets as lplt
+from ..fiber_tensor import FiberTensorFields, Level, LevelFType
+
+
+class DenseLevelFields(NamedTuple):
+    tns: FiberTensorFields
+    nind: int
+    pos: asm.AssemblyNode
+    op: Any


 @dataclass(unsafe_hash=True)
 class DenseLevelFType(LevelFType, ABC):
     lvl: Any
     dimension_type: Any = None
+    pos: asm.AssemblyNode | None = None
+    op: Any = None

     def __post_init__(self):
         if self.dimension_type is None:
             self.dimension_type = np.intp

-    def __call__(self, shape):
+    def __call__(self, shape, val=None):
         """
         Creates an instance of DenseLevel with the given ftype.
         Args:
             shape: The shape to be used for the level. (mandatory)
         Returns:
             An instance of DenseLevel.
         """
-        lvl = self.lvl(shape=shape[1:])
+        lvl = self.lvl(shape[1:], val)
         return DenseLevel(self, lvl, self.dimension_type(shape[0]))

+    def __str__(self):
+        return f"DenseLevelFType({self.lvl})"
+
     @property
     def ndim(self):
         return 1 + self.lvl.ndim
@@ -63,6 +79,45 @@ def buffer_factory(self):
         """
         return self.lvl.buffer_factory

+    def get_fields_class(self, tns, nind, pos, op):
+        return DenseLevelFields(tns, nind, pos, op)
+
+    def unfurl(self, ctx, tns, ext, mode, proto):
+        def child_accessor(ctx, idx):
+            pos_2 = asm.Variable(
+                ctx.freshen(ctx.idx, f"_pos_{self.ndim - 1}"), self.pos
+            )
+            ctx.exec(
+                asm.Assign(
+                    pos_2,
+                    asm.Call(
+                        asm.Literal(operator.add),
+                        [
+                            tns.obj.pos,
+                            asm.Call(
+                                asm.Literal(operator.mul),
+                                [
+                                    tns.obj.tns.stride[tns.obj.nind],
+                                    asm.Variable(ctx.idx.name, ctx.idx.type_),
+                                ],
+                            ),
+                        ],
+                    ),
+                )
+            )
+            return ntn.Stack(
+                self.lvl.get_fields_class(
+                    tns.obj.tns, tns.obj.nind + 1, pos_2, tns.obj.op
+                ),
+                self.lvl,
+            )
+
+        return lplt.Lookup(
+            body=lambda ctx, idx: lplt.Leaf(
+                body=lambda ctx: child_accessor(ctx, idx),
+            )
+        )
+

 def dense(lvl, dimension_type=None):
     return DenseLevelFType(lvl, dimension_type=dimension_type)
@@ -77,11 +132,15 @@ class DenseLevel(Level):
     _format: DenseLevelFType
     lvl: Any
     dimension: Any
+    pos: asm.AssemblyNode | None = None

     @property
-    def shape(self):
+    def shape(self) -> tuple:
         return (self.dimension, *self.lvl.shape)

     @property
     def ftype(self):
         return self._format
+
+    def with_pos(self, pos: asm.AssemblyNode) -> "DenseLevel":
+        return DenseLevel(self._format, self.lvl, self.dimension, pos)
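The heart of `unfurl` is the assembly emitted by `child_accessor`: `pos_2 = pos + stride[nind] * idx`. Each dense level consumes one loop index and advances the flat buffer position by that dimension's stride, i.e. ordinary strided (row-major) addressing. A plain-Python model of the arithmetic the generated code performs per element (illustration only, not library API):

    # Fold each (stride, index) pair into the flat position, one level at a
    # time, mirroring pos_2 = pos + stride[nind] * idx in child_accessor.
    def flat_position(strides, indices):
        pos = 0
        for stride, idx in zip(strides, indices):
            pos = pos + stride * idx
        return pos

    # Element (2, 3) of a 3x4 row-major array with strides (4, 1):
    assert flat_position((4, 1), (2, 3)) == 11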
