1 change: 1 addition & 0 deletions src/Finch.jl
Submodule Finch.jl added at 8e4da9
4 changes: 4 additions & 0 deletions src/finchlite/__init__.py
@@ -27,6 +27,8 @@
add,
all,
any,
argmax,
argmin,
asarray,
asin,
asinh,
@@ -129,6 +131,8 @@
"add",
"all",
"any",
"argmax",
"argmin",
"asarray",
"asin",
"asinh",
9 changes: 9 additions & 0 deletions src/finchlite/algebra/operator.py
@@ -17,6 +17,15 @@ def ifelse(a, b, c):
return a if c else b


# minby/maxby reduce (value, index) pairs, comparing by value. On ties the
# left operand wins, so the first occurrence of the extremum is kept.
def minby(a, b):
return a if a[0] <= b[0] else b


def maxby(a, b):
return a if a[0] >= b[0] else b
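

# Editorial sketch (not part of this diff): folding minby/maxby over
# (value, index) pairs with functools.reduce illustrates the tie-breaking --
# the left operand wins ties, so the first extremum is kept, as in NumPy.
from functools import reduce as py_reduce

pairs = [(3.0, 0), (1.0, 1), (2.0, 2), (1.0, 3)]
assert py_reduce(minby, pairs) == (1.0, 1)  # first minimum wins the tie
assert py_reduce(maxby, pairs) == (3.0, 0)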


def promote_min(a, b):
cast = algebra.promote_type(a, b)
return cast(min(a, b))
4 changes: 3 additions & 1 deletion src/finchlite/interface/__init__.py
@@ -66,7 +66,7 @@
vecdot,
)
from .fuse import compute, fuse, fused, get_default_scheduler, set_default_scheduler
from .lazy import LazyTensor, asarray, defer
from .lazy import LazyTensor, argmax, argmin, asarray, defer
from .scalar import Scalar, ScalarFType

__all__ = [
@@ -80,6 +80,8 @@
"add",
"all",
"any",
"argmax",
"argmin",
"asarray",
"asin",
"asinh",
64 changes: 64 additions & 0 deletions src/finchlite/interface/lazy.py
@@ -673,6 +673,70 @@ def prod(
return reduce(operator.mul, x, axis=axis, dtype=dtype, keepdims=keepdims)


# Helpers for argmin/argmax.
class LazyIndices:
    """Lazily maps a multi-index to its row-major (C-order) flat index."""

    def __init__(self, shape):
        self.shape = shape

    def __getitem__(self, idxs):
        # Horner-style accumulation: ((i0 * s1 + i1) * s2 + i2) * ...
        flat_index = idxs[0]
        for i in range(1, len(self.shape)):
            flat_index = flat_index * self.shape[i] + idxs[i]
        return flat_index
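

# Editorial sketch (not part of this diff): the flattening above is plain
# row-major (C) order, so it agrees with np.ravel_multi_index.
idx = LazyIndices((2, 3, 4))
assert idx[(1, 2, 3)] == 1 * (3 * 4) + 2 * 4 + 3  # == 23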


def pair(value, index):
    # Explicit (value, index) constructor for the elementwise pairing step;
    # the tuple builtin takes a single iterable, not two positional values.
    return (value, index)


def last(tup):
    # After the by-value reduction, keep only the index slot of the pair.
    return tup[-1]


def argmin(x, axis=None):
    x = defer(x)
    shape = x.shape

    if axis is None:
        indices = LazyIndices(shape)
    else:
        broadcast_indices = LazyTensor(
            "i",
            shape=(x.shape[axis],),
            fill_value=x.fill_value,
            element_type=x.element_type,
        )
        indices = expand_dims(
            broadcast_indices, axis=[j for j in range(x.ndim) if j != axis]
        )

    # Pair each element with its index, reduce the pairs by value, then drop
    # the value. minby is assumed imported from finchlite.algebra.operator;
    # the stdlib operator module used elsewhere in this file has no minby.
    paired = elementwise(pair, x, indices)
    reduced = reduce(minby, paired, axis=axis, init=(float("inf"), 0))
    return elementwise(last, reduced)


def argmax(x, axis=None):
    x = defer(x)
    shape = x.shape

    if axis is None:
        indices = LazyIndices(shape)
    else:
        broadcast_indices = LazyTensor(
            "i",
            shape=(x.shape[axis],),
            fill_value=x.fill_value,
            element_type=x.element_type,
        )
        indices = expand_dims(
            broadcast_indices, axis=[j for j in range(x.ndim) if j != axis]
        )

    paired = elementwise(pair, x, indices)
    # The running maximum must start at -inf, not +inf: with a +inf
    # initializer every element would lose the comparison and argmax would
    # always return index 0. maxby is assumed imported from
    # finchlite.algebra.operator, as with minby above.
    reduced = reduce(maxby, paired, axis=axis, init=(float("-inf"), 0))
    return elementwise(last, reduced)
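

# Editorial usage sketch (not part of this diff): argmin/argmax defer their
# input themselves, and compute comes from finchlite.interface.fuse.
#
#   import numpy as np
#   import finchlite
#
#   a = np.array([[3.0, 1.0], [2.0, 5.0]])
#   finchlite.compute(finchlite.argmax(a, axis=0))  # expect [0, 1], as np.argmax
#   finchlite.compute(finchlite.argmin(a))          # expect 1 (flat), as np.argmin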


def any(
x,
/,
4 changes: 4 additions & 0 deletions tests/test_interface.py
@@ -290,6 +290,7 @@ def test_unary_operations(a, a_wrap, ops, np_op):
finchlite.defer,
],
)
@pytest.mark.parametrize(
"ops, np_op",
[
@@ -299,6 +300,8 @@ def test_unary_operations(a, a_wrap, ops, np_op):
((finchlite.all, np.all), np.all),
((finchlite.min, np.min), np.min),
((finchlite.max, np.max), np.max),
((finchlite.argmin, np.argmin), np.argmin),
((finchlite.argmax, np.argmax), np.argmax),
((finchlite.mean, np.mean), np.mean),
((finchlite.std, np.std), np.std),
((finchlite.var, np.var), np.var),
@@ -313,6 +316,7 @@ def test_unary_operations(a, a_wrap, ops, np_op):
(0, 1),
],
)
def test_reduction_operations(a, a_wrap, ops, np_op, axis):
wa = a_wrap(a)
