Skip to content

Commit d80e3a6

Browse files
committed
[SYSTEMDS-3758] Python API Builtin triu, tril, argmin, argmax and casting Scalar <-> Matrix <-> Frame
Closes #2113
1 parent 504e751 commit d80e3a6

File tree

8 files changed

+406
-3
lines changed

8 files changed

+406
-3
lines changed

src/main/python/systemds/operator/nodes/frame.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,14 @@
4444
from systemds.context import SystemDSContext
4545

4646

47-
class Frame(OperationNode):
47+
def to_frame(self):
48+
return Frame(self.sds_context, "as.frame", [self])
49+
50+
51+
OperationNode.to_frame = to_frame
4852

53+
54+
class Frame(OperationNode):
4955
_pd_dataframe: pd.DataFrame
5056

5157
def __init__(

src/main/python/systemds/operator/nodes/matrix.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@
4141
)
4242

4343

44+
def to_matrix(self):
45+
return Matrix(self.sds_context, "as.matrix", [self])
46+
47+
48+
OperationNode.to_matrix = to_matrix
49+
50+
4451
class Matrix(OperationNode):
4552
_np_array: np.array
4653

@@ -842,5 +849,89 @@ def ifft(self, imag_input: "Matrix" = None) -> "MultiReturn":
842849

843850
return ifft_node
844851

852+
def triu(self, include_diagonal=True, return_values=True) -> "Matrix":
853+
"""Selects the upper triangular part of a matrix, configurable to include the diagonal and return values or ones
854+
855+
:param include_diagonal: boolean, default True
856+
:param return_values: boolean, default True, if set to False returns ones
857+
:return: `Matrix`
858+
"""
859+
named_input_nodes = {
860+
"target": self,
861+
"diag": self.sds_context.scalar(include_diagonal),
862+
"values": self.sds_context.scalar(return_values),
863+
}
864+
return Matrix(
865+
self.sds_context, "upper.tri", named_input_nodes=named_input_nodes
866+
)
867+
868+
def tril(self, include_diagonal=True, return_values=True) -> "Matrix":
869+
"""Selects the lower triangular part of a matrix, configurable to include the diagonal and return values or ones
870+
871+
:param include_diagonal: boolean, default True
872+
:param return_values: boolean, default True, if set to False returns ones
873+
:return: `Matrix`
874+
"""
875+
named_input_nodes = {
876+
"target": self,
877+
"diag": self.sds_context.scalar(include_diagonal),
878+
"values": self.sds_context.scalar(return_values),
879+
}
880+
return Matrix(
881+
self.sds_context, "lower.tri", named_input_nodes=named_input_nodes
882+
)
883+
884+
def argmin(self, axis: int = None) -> "OperationNode":
885+
"""Return the index of the minimum if axis is None or a column vector for row-wise / column-wise minima
886+
computation.
887+
888+
:param axis: can be 0 or 1 to do either row or column sums
889+
:return: `Matrix` representing operation for row / columns or 'Scalar' representing operation for complete
890+
"""
891+
if axis == 0:
892+
return Matrix(self.sds_context, "rowIndexMin", [self.t()])
893+
elif axis == 1:
894+
return Matrix(self.sds_context, "rowIndexMin", [self])
895+
elif axis is None:
896+
return Matrix(
897+
self.sds_context,
898+
"rowIndexMin",
899+
[self.reshape(1, self.nCol() * self.nRow())],
900+
).to_scalar()
901+
else:
902+
raise ValueError(
903+
f"Axis has to be either 0, 1 or None, for column, row or complete {self.operation}"
904+
)
905+
906+
def argmax(self, axis: int = None) -> "OperationNode":
907+
"""Return the index of the maximum if axis is None or a column vector for row-wise / column-wise maxima
908+
computation.
909+
910+
:param axis: can be 0 or 1 to do either row or column sums
911+
:return: `Matrix` representing operation for row / columns or 'Scalar' representing operation for complete
912+
"""
913+
if axis == 0:
914+
return Matrix(self.sds_context, "rowIndexMax", [self.t()])
915+
elif axis == 1:
916+
return Matrix(self.sds_context, "rowIndexMax", [self])
917+
elif axis is None:
918+
return Matrix(
919+
self.sds_context,
920+
"rowIndexMax",
921+
[self.reshape(1, self.nCol() * self.nRow())],
922+
).to_scalar()
923+
else:
924+
raise ValueError(
925+
f"Axis has to be either 0, 1 or None, for column, row or complete {self.operation}"
926+
)
927+
928+
def reshape(self, rows, cols=1):
929+
"""Gives a new shape to a matrix without changing its data.
930+
931+
:param rows: number of rows
932+
:param cols: number of columns, defaults to 1
933+
:return: `Matrix` representing operation"""
934+
return Matrix(self.sds_context, "matrix", [self, rows, cols])
935+
845936
def __str__(self):
846937
return "MatrixNode"

src/main/python/systemds/operator/nodes/scalar.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232
VALID_ARITHMETIC_TYPES,
3333
VALID_INPUT_TYPES,
3434
)
35-
from systemds.utils.converters import numpy_to_matrix_block
35+
36+
37+
def to_scalar(self):
38+
return Scalar(self.sds_context, "as.scalar", [self])
39+
40+
41+
OperationNode.to_scalar = to_scalar
3642

3743

3844
class Scalar(OperationNode):
@@ -67,6 +73,8 @@ def code_line(
6773
named_input_vars: Dict[str, str],
6874
) -> str:
6975
if self.__assign:
76+
if type(self.operation) is bool:
77+
self.operation = "TRUE" if self.operation else "FALSE"
7078
return f"{var_name}={self.operation};"
7179
else:
7280
return super().code_line(var_name, unnamed_input_vars, named_input_vars)
@@ -289,6 +297,20 @@ def to_string(self, **kwargs: Dict[str, VALID_INPUT_TYPES]) -> "Scalar":
289297
"""
290298
return Scalar(self.sds_context, "toString", [self], named_input_nodes=kwargs)
291299

300+
def to_int(self) -> "Scalar":
301+
return Scalar(
302+
self.sds_context,
303+
"as.integer",
304+
[self],
305+
)
306+
307+
def to_boolean(self) -> "Scalar":
308+
return Scalar(
309+
self.sds_context,
310+
"as.logical",
311+
[self],
312+
)
313+
292314
def isNA(self) -> "Scalar":
293315
"""Computes a boolean indicator matrix of the same shape as the input, indicating where NA (not available)
294316
values are located. Currently, NA is only capturing NaN values.

src/main/python/systemds/operator/operation_node.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,12 @@ def print(self, **kwargs: Dict[str, VALID_INPUT_TYPES]) -> "OperationNode":
202202
To get the returned string look at the stdout of SystemDSContext.
203203
"""
204204
return OperationNode(self.sds_context, "print", [self], kwargs)
205+
206+
def to_frame(self):
207+
raise NotImplementedError("should have been overwritten in frame.py")
208+
209+
def to_matrix(self):
210+
raise NotImplementedError("should have been overwritten in matrix.py")
211+
212+
def to_scalar(self):
213+
raise NotImplementedError("should have been overwritten in scalar.py")

src/main/python/systemds/utils/converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def pandas_to_frame_block(sds, pd_df: pd.DataFrame):
104104
np.dtype(np.int32): jvm.org.apache.sysds.common.Types.ValueType.INT32,
105105
np.dtype(np.float32): jvm.org.apache.sysds.common.Types.ValueType.FP32,
106106
np.dtype(np.uint8): jvm.org.apache.sysds.common.Types.ValueType.UINT8,
107-
np.dtype(np.character): jvm.org.apache.sysds.common.Types.ValueType.CHARACTER,
107+
np.dtype(np.str_): jvm.org.apache.sysds.common.Types.ValueType.CHARACTER,
108108
}
109109
schema = []
110110
col_names = []
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
22+
import unittest
23+
import numpy as np
24+
from systemds.context import SystemDSContext
25+
26+
np.random.seed(7)
27+
m = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]])
28+
M = np.random.random_integers(9, size=300).reshape(100, 3)
29+
p = np.array([0.25, 0.5, 0.75])
30+
m2 = np.array([1, 2, 3, 4, 5])
31+
w2 = np.array([1, 1, 1, 1, 5])
32+
33+
34+
def weighted_quantiles(values, weights, quantiles=0.5):
35+
i = np.argsort(values)
36+
c = np.cumsum(weights[i])
37+
return values[i[np.searchsorted(c, np.array(quantiles) * c[-1])]]
38+
39+
40+
class TestARGMINMAX(unittest.TestCase):
41+
def setUp(self):
42+
self.sds = SystemDSContext()
43+
44+
def tearDown(self):
45+
self.sds.close()
46+
47+
def test_argmin_basic1(self):
48+
sds_input = self.sds.from_numpy(m)
49+
sds_result = sds_input.argmin(0).compute()
50+
np_result = np.argmin(m, axis=0).reshape(-1, 1)
51+
assert np.allclose(sds_result - 1, np_result, 1e-9)
52+
53+
def test_argmin_basic2(self):
54+
sds_input = self.sds.from_numpy(m)
55+
sds_result = sds_input.argmin(1).compute()
56+
np_result = np.argmin(m, axis=1).reshape(-1, 1)
57+
assert np.allclose(sds_result - 1, np_result, 1e-9)
58+
59+
def test_argmin_basic3(self):
60+
sds_input = self.sds.from_numpy(m)
61+
sds_result = sds_input.argmin().compute(verbose=True)
62+
np_result = np.argmin(m)
63+
assert np.allclose(sds_result - 1, np_result, 1e-9)
64+
65+
def test_argmin_basic4(self):
66+
sds_input = self.sds.from_numpy(m)
67+
with self.assertRaises(ValueError):
68+
sds_input.argmin(3)
69+
70+
def test_argmax_basic1(self):
71+
sds_input = self.sds.from_numpy(m)
72+
sds_result = sds_input.argmax(0).compute()
73+
np_result = np.argmax(m, axis=0).reshape(-1, 1)
74+
assert np.allclose(sds_result - 1, np_result, 1e-9)
75+
76+
def test_argmax_basic2(self):
77+
sds_input = self.sds.from_numpy(m)
78+
sds_result = sds_input.argmax(1).compute()
79+
np_result = np.argmax(m, axis=1).reshape(-1, 1)
80+
assert np.allclose(sds_result - 1, np_result, 1e-9)
81+
82+
def test_argmax_basic3(self):
83+
sds_input = self.sds.from_numpy(m)
84+
sds_result = sds_input.argmax().compute()
85+
np_result = np.argmax(m)
86+
assert np.allclose(sds_result - 1, np_result, 1e-9)
87+
88+
def test_argmax_basic4(self):
89+
sds_input = self.sds.from_numpy(m)
90+
with self.assertRaises(ValueError):
91+
sds_input.argmax(3)
92+
93+
94+
if __name__ == "__main__":
95+
unittest.main()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
22+
import unittest
23+
import numpy as np
24+
from systemds.context import SystemDSContext
25+
from pandas import DataFrame
26+
from numpy import ndarray
27+
28+
29+
class TestDIAG(unittest.TestCase):
30+
def setUp(self):
31+
self.sds = SystemDSContext()
32+
33+
def tearDown(self):
34+
self.sds.close()
35+
36+
def test_casting_basic1(self):
37+
sds_input = self.sds.from_numpy(np.array([[1]]))
38+
sds_result = sds_input.to_scalar().compute()
39+
self.assertTrue(type(sds_result) == float)
40+
41+
def test_casting_basic2(self):
42+
sds_input = self.sds.from_numpy(np.array([[1]]))
43+
sds_result = sds_input.to_frame().compute()
44+
self.assertTrue(type(sds_result) == DataFrame)
45+
46+
def test_casting_basic3(self):
47+
sds_result = self.sds.scalar(1.0).to_frame().compute()
48+
self.assertTrue(type(sds_result) == DataFrame)
49+
50+
def test_casting_basic4(self):
51+
sds_result = self.sds.scalar(1.0).to_matrix().compute()
52+
self.assertTrue(type(sds_result) == ndarray)
53+
54+
def test_casting_basic5(self):
55+
ar = ndarray((2, 2))
56+
df = DataFrame(ar)
57+
sds_result = self.sds.from_pandas(df).to_matrix().compute()
58+
self.assertTrue(type(sds_result) == ndarray and np.allclose(ar, sds_result))
59+
60+
def test_casting_basic6(self):
61+
ar = ndarray((1, 1))
62+
df = DataFrame(ar)
63+
sds_result = self.sds.from_pandas(df).to_scalar().compute()
64+
self.assertTrue(type(sds_result) == float)
65+
66+
def test_casting_basic7(self):
67+
sds_result = self.sds.scalar(1.0).to_int().compute()
68+
self.assertTrue(type(sds_result) == int and sds_result)
69+
70+
def test_casting_basic8(self):
71+
sds_result = self.sds.scalar(1.0).to_boolean().compute()
72+
self.assertTrue(type(sds_result) == bool)
73+
74+
75+
if __name__ == "__main__":
76+
unittest.main()

0 commit comments

Comments
 (0)