Skip to content

Commit 1020466

Browse files
committed
TL: fixed tests
1 parent c4a99d4 commit 1020466

File tree

4 files changed

+32
-39
lines changed

4 files changed

+32
-39
lines changed

blockops/taskPool.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import sympy as sy
22
import re
33
import warnings
4-
import time
54

65
from blockops.run import PintRun
76

@@ -223,21 +222,22 @@ def addTask(self, ope, inp, dep, n, k, result=None):
223222
res : Expression
224223
The result of this task (full expression).
225224
"""
226-
tmpRes = ope * inp
227-
notZero = type(tmpRes) != sy.core.numbers.Zero
225+
tmpRes = sy.simplify(ope * inp)
228226

229-
if tmpRes in self.results and notZero:
227+
# Check if result is already computed
228+
notZero = type(tmpRes) != sy.core.numbers.Zero
229+
if notZero and tmpRes in self.results:
230230
task = self.results[tmpRes]
231231
return task, tmpRes
232232

233-
check2 = -ope * inp
234-
if check2 in self.results and notZero:
235-
task = self.results[check2]
233+
# Check if negative result is already computed
234+
neg = ope * sy.expand(-inp, deep=False)
235+
if notZero and neg in self.results:
236+
task = self.results[neg]
236237
return task, tmpRes
237-
238-
check3 = ope * -inp
239-
if check3 in self.results and notZero:
240-
task = self.results[check3]
238+
neg = sy.expand(-ope, deep=False) * inp
239+
if notZero and neg in self.results:
240+
task = self.results[neg]
241241
return task, tmpRes
242242

243243
# Task not in pool, create and add it

blockops/tests/test_blockMethods.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from .data.methodData import results
3+
from blockops.tests.data.methodData import results
44

55
from blockops.run import PintRun
66
from blockops.block import BlockOperator, I
@@ -25,17 +25,16 @@
2525

2626

2727
def checkResults(method, run, pool):
28-
i = 0
29-
for key, value in run.facBlockRules.items():
28+
for i, (key, value) in enumerate(run.facBlockRules.items()):
3029
assert results[method]['blockRules'][i][0] == str(value["result"])
3130
assert results[method]['blockRules'][i][1] == str(value["rule"])
32-
i = i + 1
33-
34-
i = 0
35-
for key, value in pool.pool.items():
36-
assert results[method]['taskPool'][i][0] == str(key)
37-
assert str(value.fullOP) in results[method]['taskPool'][i][1]
38-
i = i + 1
31+
for i, (key, value) in enumerate(pool.pool.items()):
32+
refKey, refVal = results[method]['taskPool'][i]
33+
assert refKey == str(key)
34+
if isinstance(refVal, list):
35+
assert str(value.fullOP) in refVal
36+
else:
37+
assert str(value.fullOP) == refVal
3938

4039

4140
class TestMethods:

blockops/utils/params.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,20 +123,17 @@ def extractParamDocs(cls, *names):
123123

124124
if docs is None:
125125
raise ValueError(f'undocumented class {cls}')
126-
127126
for name in names:
128-
iStart = docs.find(f'\n {name} :')
127+
iStart = docs.find(f'{name} :')
129128
if iStart == -1:
130-
iStart = docs.find(f'\n **{name} :')
131-
if iStart == -1:
132-
raise ValueError(f'{name} parameter not in {cls} docs')
133-
docLines = docs[iStart:].splitlines()[2:]
129+
raise ValueError(f'{name} parameter not in {cls} docs')
130+
docLines = docs[iStart:].splitlines()[1:]
134131
descr = []
135132
for line in docLines:
136-
if line.startswith(8 * ' '):
137-
descr.append(line.strip())
138-
elif line.strip() == '':
133+
if line.strip() == "":
139134
continue
135+
elif line.startswith(" "):
136+
descr.append(line.strip())
140137
else:
141138
break
142139
if len(descr) == 0:

blockops/utils/vectorize.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ def matVecMul(mat, u):
3434
return np.matmul(mat, u[..., None]).squeeze(axis=-1)
3535

3636

37-
def matVecInv(mat, u):
37+
def matVecInv(mat, vec):
3838
r"""
3939
Compute vectorized Matrix Vector Inversion :math:`A^{-1}x` (A / x)
4040
4141
Parameters
4242
----------
4343
mat : np.ndarray, size (nDOF, M, M) or (M, M)
4444
Matrix or array of matrices.
45-
u : np.ndarray, size (nDOF, M) or (M,)
45+
vec : np.ndarray, size (nDOF, M) or (M,)
4646
Vector or array of vectors.
4747
4848
Returns
@@ -57,13 +57,10 @@ def matVecInv(mat, u):
5757
- matVecInv for (nDOF, M, M), (M,) -> (nDOF, M) <=> (M, M) \ (M,) for each nDOF
5858
- matVecInv for (M, M), (M,)) -> (M,) <=> (M, M) \ (M,)
5959
"""
60-
try:
61-
return np.linalg.solve(mat, u)
62-
except ValueError:
63-
try:
64-
return np.linalg.solve(mat[None, ...], u)
65-
except ValueError:
66-
return np.linalg.solve(mat, u[None, ...])
60+
mat, vec = np.asarray(mat), np.asarray(vec)
61+
if mat.ndim > 2 and vec.ndim > 1:
62+
assert mat.shape[0] == vec.shape[0], "different nDOF for mat and vec"
63+
return np.linalg.solve(mat, vec[..., None]).squeeze()
6764

6865

6966
def matMatMul(m1, m2):

0 commit comments

Comments
 (0)