Skip to content

Commit f6e33ce

Browse files
authored
Merge pull request #219 from ToFuProject/Issue218_NonDirectVectorBasis
[#218] `_check_vectbasis(direct=bool)` implemented
2 parents 06c477f + 6da758b commit f6e33ce

File tree

1 file changed

+47
-14
lines changed

1 file changed

+47
-14
lines changed

datastock/_generic_check.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -420,14 +420,34 @@ def _check_vectbasis(
420420
e2=None,
421421
dim=None,
422422
tol=None,
423+
direct=None,
423424
):
425+
""" Check a 2d or 3d set of unit vectors
426+
427+
Check that:
428+
- vectors are defined (if None, they can be inferred)
429+
430+
Normalizes vectorto be unit vectors
431+
Optionally (default) check the vectors forma direct basis
432+
"""
433+
434+
# ----------------
435+
# check inputs
436+
# ----------------
424437

425438
# dim
426439
dim = _check_var(dim, 'dim', types=int, default=3, allowed=[2, 3])
427440

428441
# tol
429442
tol = _check_var(tol, 'tol', types=float, default=1.e-14, sign='>0.')
430443

444+
# direct
445+
direct = _check_var(direct, 'direct', types=bool, default=True)
446+
447+
# ----------------------
448+
# check what's provided
449+
# ----------------------
450+
431451
# check is provided
432452
if e0 is not None:
433453
e0 = _check_flat1darray(e0, 'e0', size=dim, dtype=float, norm=True)
@@ -436,13 +456,23 @@ def _check_vectbasis(
436456
if e2 is not None:
437457
e2 = _check_flat1darray(e2, 'e2', size=dim, dtype=float, norm=True)
438458

459+
# preliminary
460+
allnone = all([ee is None for ee in [e0, e1, e2][:dim]])
461+
if allnone is True:
462+
lstr = [f"\t- {ee}" for ee in ['e0', 'e1', 'e2'][:dim]]
463+
msg = (
464+
f"For a basis f dimension {dim}, provide at least one of:\n"
465+
+ "\n".join(lstr)
466+
)
467+
raise Exception(msg)
468+
469+
# ----------------------
470+
# dim = 2
471+
# ----------------------
472+
439473
# vectors
440474
if dim == 2:
441475

442-
if e0 is None and e1 is None:
443-
msg = "Please provide e0 and/or e1!"
444-
raise Exception(msg)
445-
446476
# complete if missing
447477
if e0 is None:
448478
e0 = np.r_[e1[1], -e1[0]]
@@ -455,24 +485,26 @@ def _check_vectbasis(
455485
raise Exception(msg)
456486

457487
# direct
458-
if np.abs(np.cross(e0, e1).tolist() - 1.) < tol:
459-
msg = "Non-direct basis"
460-
raise Exception(msg)
488+
if direct is True:
489+
if np.abs(np.cross(e0, e1).tolist() - 1.) < tol:
490+
msg = "Non-direct basis"
491+
raise Exception(msg)
461492

462493
return e0, e1
463494

495+
# ----------------------
496+
# dim = 3
497+
# ----------------------
498+
464499
else:
465-
if e0 is None and e1 is None and e2 is None:
466-
msg = "Please provide at least e0, e1 or e2!"
467-
raise Exception(msg)
468500

469501
# complete if 2 missing
470502
if e0 is None and e1 is None:
471503
e1 = _get_horizontal_unitvect(ee=e2)
472504
elif e0 is None and e2 is None:
473505
e2 = _get_vertical_unitvect(ee=e1)
474506
elif e1 is None and e2 is None:
475-
e2 = _get_vertical_unitvect(ee=e0)
507+
e2 = _get_horizontal_unitvect(ee=e0)
476508

477509
# complete if 1 missing
478510
if e0 is None:
@@ -502,9 +534,10 @@ def _check_vectbasis(
502534
raise Exception(msg)
503535

504536
# direct
505-
if not np.allclose(np.cross(e0, e1), e2, atol=tol, rtol=1e-6):
506-
msg = "Non-direct basis"
507-
raise Exception(msg)
537+
if direct is True:
538+
if not np.allclose(np.cross(e0, e1), e2, atol=tol, rtol=1e-6):
539+
msg = "Non-direct basis"
540+
raise Exception(msg)
508541

509542
return e0, e1, e2
510543

0 commit comments

Comments
 (0)