Skip to content

Commit e02cd84

Browse files
committed
WIP Dense level
1 parent 48b9353 commit e02cd84

File tree

9 files changed

+147
-28
lines changed

9 files changed

+147
-28
lines changed

src/finchlite/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@
108108
DenseLevelFType,
109109
ElementLevelFType,
110110
FiberTensorFType,
111+
dense,
112+
element,
113+
fiber_tensor,
111114
)
112115

113116
__all__ = [
@@ -158,11 +161,14 @@
158161
"cos",
159162
"cosh",
160163
"defer",
164+
"dense",
161165
"dimension",
166+
"element",
162167
"element_type",
163168
"elementwise",
164169
"equal",
165170
"expand_dims",
171+
"fiber_tensor",
166172
"fill_value",
167173
"fisinstance",
168174
"flatten",

src/finchlite/codegen/numpy_buffer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ def __str__(self):
7676
arr_str = str(self.arr).replace("\n", "")
7777
return f"np_buf({arr_str})"
7878

79+
def __repr__(self):
80+
arr_repr = repr(self.arr).replace("\n", "")
81+
return f"NumpyBuffer({arr_repr})"
82+
7983

8084
class NumpyBufferFType(CBufferFType, NumbaBufferFType, CStackFType):
8185
"""

src/finchlite/interface/lazy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,24 +332,24 @@ def __ne__(self, other):
332332
register_property(LazyTensor, "asarray", "__attr__", lambda x: x)
333333

334334

335-
def asarray(arg: Any, format="bufferized") -> Any:
335+
def asarray(arg: Any, format=None) -> Any:
336336
"""
337337
Convert given argument and return wrapper type instance.
338338
If input argument is already array type, return unchanged.
339339
340340
Args:
341341
arg: The object to be converted.
342+
format: The format for the result array.
342343
343344
Returns:
344345
The array type result of the given object.
345346
"""
346-
if format != "bufferized":
347-
raise Exception(f"Only bufferized format is now supported, got: {format}")
347+
if format is None:
348+
if hasattr(arg, "asarray"):
349+
return arg.asarray()
350+
return query_property(arg, "asarray", "__attr__")
348351

349-
if hasattr(arg, "asarray"):
350-
return arg.asarray()
351-
352-
return query_property(arg, "asarray", "__attr__")
352+
return format(arg)
353353

354354

355355
def defer(arr) -> LazyTensor:

src/finchlite/tensor/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, tensor
1+
from .fiber_tensor import FiberTensor, FiberTensorFType, Level, LevelFType, fiber_tensor
22
from .level import (
33
DenseLevel,
44
DenseLevelFType,
@@ -19,5 +19,5 @@
1919
"LevelFType",
2020
"dense",
2121
"element",
22-
"tensor",
22+
"fiber_tensor",
2323
]

src/finchlite/tensor/fiber_tensor.py

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Generic, TypeVar
3+
from typing import Generic, NamedTuple, TypeVar
44

5-
from finchlite.algebra import Tensor, TensorFType
6-
from finchlite.symbolic import FType, FTyped
5+
import numpy as np
6+
7+
from .. import finch_assembly as asm
8+
from .. import finch_notation as ntn
9+
from ..algebra import Tensor, register_property
10+
from ..codegen.numpy_buffer import NumpyBuffer
11+
from ..compile.lower import FinchTensorFType
12+
from ..symbolic import FType, FTyped
713

814

915
class LevelFType(FType, ABC):
@@ -161,8 +167,14 @@ def buffer_factory(self):
161167
return self.lvl.buffer_factory
162168

163169

170+
class FiberTensorFields(NamedTuple):
171+
stride: tuple[asm.Variable, ...]
172+
buf: asm.Variable
173+
buf_s: asm.Slot
174+
175+
164176
@dataclass(unsafe_hash=True)
165-
class FiberTensorFType(TensorFType):
177+
class FiberTensorFType(FinchTensorFType):
166178
"""
167179
An abstract base class representing the ftype of a fiber tensor.
168180
@@ -177,11 +189,14 @@ def __post_init__(self):
177189
if self._position_type is None:
178190
self._position_type = self.lvl.position_type
179191

180-
def __call__(self, shape):
192+
def __call__(self, *, shape=None, val=None):
181193
"""
182194
Creates an instance of a FiberTensor with the given arguments.
183195
"""
184-
return FiberTensor(self.lvl(shape), self.lvl.position_type(1))
196+
if shape is None:
197+
shape = val.shape
198+
val = NumpyBuffer(val.reshape[-1])
199+
return FiberTensor(self.lvl(shape, val), self.lvl.position_type(1))
185200

186201
@property
187202
def shape(self):
@@ -209,24 +224,82 @@ def position_type(self):
209224

210225
@property
211226
def buffer_factory(self):
212-
"""
213-
Returns the ftype of the buffer used for the fibers.
214-
This is typically a NumpyBufferFType or similar.
215-
"""
216227
return self.lvl.buffer_factory
217228

218-
219-
def tensor(lvl: LevelFType, position_type: type | None = None):
229+
def unfurl(self, ctx, tns, ext, mode, proto):
230+
op = None
231+
if isinstance(mode, ntn.Update):
232+
op = mode.op
233+
tns = ctx.resolve(tns).obj
234+
acc_t = self.lvl.unfurl(
235+
ctx,
236+
)
237+
238+
# acc_t = BufferizedNDArrayAccessorFType(self, 0, self.buf_t.length_type, op)
239+
# obj = BufferizedNDArrayAccessorFields(
240+
# tns, 0, asm.Literal(self.buf_t.length_type(0)), op
241+
# )
242+
return acc_t.unfurl(ctx, ntn.Stack(obj, acc_t), ext, mode, proto)
243+
244+
def lower_freeze(self, ctx, tns, op):
245+
raise NotImplementedError
246+
247+
def lower_thaw(self, ctx, tns, op):
248+
raise NotImplementedError
249+
250+
def lower_unwrap(self, ctx, obj):
251+
raise NotImplementedError
252+
253+
def lower_increment(self, ctx, obj, val):
254+
raise NotImplementedError
255+
256+
def lower_declare(self, ctx, tns, init, op, shape):
257+
i_var = asm.Variable("i", self.buffer_factory.length_type)
258+
body = asm.Store(
259+
tns.obj.buf_s,
260+
i_var,
261+
asm.Literal(init.val),
262+
)
263+
ctx.exec(
264+
asm.ForLoop(i_var, asm.Literal(np.intp(0)), asm.Length(tns.obj.buf_s), body)
265+
)
266+
return
267+
268+
def asm_unpack(self, ctx, var_n, val):
269+
"""
270+
Unpack the into asm context.
271+
"""
272+
stride = []
273+
shape_type = self.shape_type
274+
for i in range(self.ndim):
275+
stride_i = asm.Variable(f"{var_n}_stride_{i}", shape_type[i])
276+
stride.append(stride_i)
277+
stride_e = asm.GetAttr(val, asm.Literal("strides"))
278+
stride_i_e = asm.GetAttr(stride_e, asm.Literal(f"element_{i}"))
279+
ctx.exec(asm.Assign(stride_i, stride_i_e))
280+
buf = asm.Variable(f"{var_n}_buf", self.buffer_factory)
281+
buf_e = asm.GetAttr(val, asm.Literal("buf"))
282+
ctx.exec(asm.Assign(buf, buf_e))
283+
buf_s = asm.Slot(f"{var_n}_buf_slot", self.buffer_factory)
284+
ctx.exec(asm.Unpack(buf_s, buf))
285+
286+
return FiberTensorFields(tuple(stride), buf, buf_s)
287+
288+
289+
def fiber_tensor(lvl: LevelFType, position_type: type | None = None):
220290
"""
221291
Creates a FiberTensorFType with the given level ftype and position type.
222292
223293
Args:
224294
lvl: The level ftype to be used for the tensor.
225-
pos_type: The type of positions within the tensor. Defaults to None.
295+
position_type: The type of positions within the tensor. Defaults to None.
226296
227297
Returns:
228298
An instance of FiberTensorFType.
229299
"""
230300
# mypy does not understand that dataclasses generate __hash__ and __eq__
231301
# https://github.com/python/mypy/issues/19799
232302
return FiberTensorFType(lvl, position_type) # type: ignore[abstract]
303+
304+
305+
register_property(FiberTensor, "asarray", "__attr__", lambda x: x)

src/finchlite/tensor/level/dense_level.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66

7+
from ...compile import looplets as lplt
78
from ..fiber_tensor import Level, LevelFType
89

910

@@ -16,15 +17,15 @@ def __post_init__(self):
1617
if self.dimension_type is None:
1718
self.dimension_type = np.intp
1819

19-
def __call__(self, shape):
20+
def __call__(self, shape, val=None):
2021
"""
2122
Creates an instance of DenseLevel with the given ftype.
2223
Args:
2324
shape: The shape to be used for the level. (mandatory)
2425
Returns:
2526
An instance of DenseLevel.
2627
"""
27-
lvl = self.lvl(shape=shape[1:])
28+
lvl = self.lvl(shape[1:], val)
2829
return DenseLevel(self, lvl, self.dimension_type(shape[0]))
2930

3031
@property
@@ -63,6 +64,16 @@ def buffer_factory(self):
6364
"""
6465
return self.lvl.buffer_factory
6566

67+
def unfurl(self):
68+
def child_accessor(ctx, idx):
69+
self.lvl
70+
71+
return lplt.Lookup(
72+
body=lambda ctx, idx: lplt.Leaf(
73+
body=lambda ctx: child_accessor(ctx, idx),
74+
)
75+
)
76+
6677

6778
def dense(lvl, dimension_type=None):
6879
return DenseLevelFType(lvl, dimension_type=dimension_type)

src/finchlite/tensor/level/element_level.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from ...codegen import NumpyBufferFType
7+
from ...compile import looplets as lplt
78
from ...symbolic import FType, ftype
89
from ..fiber_tensor import Level, LevelFType
910

@@ -26,17 +27,18 @@ def __post_init__(self):
2627
self._element_type = self.val_format.element_type
2728
self._fill_value = self._element_type(self._fill_value)
2829

29-
def __call__(self, shape=()):
30+
def __call__(self, shape=(), val=None):
3031
"""
3132
Creates an instance of ElementLevel with the given ftype.
3233
Args:
33-
fmt: The ftype to be used for the level.
34+
shape: Should be always `()`, used for validation.
35+
val: The value to store in the ElementLevel instance.
3436
Returns:
3537
An instance of ElementLevel.
3638
"""
3739
if len(shape) != 0:
3840
raise ValueError("ElementLevelFType must be called with an empty shape.")
39-
return ElementLevel(self)
41+
return ElementLevel(self, val)
4042

4143
@property
4244
def ndim(self):
@@ -62,6 +64,16 @@ def shape_type(self):
6264
def buffer_factory(self):
6365
return self._buffer_factory
6466

67+
def unfurl(self):
68+
def child_accessor(ctx, idx):
69+
pass
70+
71+
return lplt.Lookup(
72+
body=lambda ctx, idx: lplt.Leaf(
73+
body=lambda ctx: child_accessor(ctx, idx),
74+
)
75+
)
76+
6577

6678
def element(
6779
fill_value=None,
@@ -78,6 +90,7 @@ def element(
7890
element_type: The type of elements stored in the level.
7991
position_type: The type of positions within the level.
8092
buffer_factory: The factory used to create buffers for the level.
93+
val_format: Format of the value stored in the level.
8194
8295
Returns:
8396
An instance of ElementLevelFType.

src/finchlite/tensor/level/sparse_list_level.py

Whitespace-only changes.

tests/test_tensor.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
DenseLevelFType,
55
ElementLevelFType,
66
FiberTensorFType,
7+
NumpyBuffer,
78
NumpyBufferFType,
9+
dense,
10+
element,
11+
fiber_tensor,
812
)
913

1014

1115
def test_fiber_tensor_attributes():
1216
fmt = FiberTensorFType(DenseLevelFType(ElementLevelFType(0.0)))
1317
shape = (3,)
14-
a = fmt(shape)
18+
a = fmt(shape=shape)
1519

1620
# Check shape attribute
1721
assert a.shape == shape
@@ -33,3 +37,11 @@ def test_fiber_tensor_attributes():
3337

3438
# Check buffer_format exists
3539
assert a.buffer_factory == NumpyBufferFType
40+
41+
42+
def test_fiber_tensor():
43+
fmt = fiber_tensor(
44+
dense(dense(element(np.int64(0), np.int64, np.intp, NumpyBufferFType)))
45+
)
46+
47+
fmt(shape=(3, 4), val=NumpyBuffer(np.arange(12)))

0 commit comments

Comments
 (0)