@@ -420,14 +420,34 @@ def _check_vectbasis(
420420 e2 = None ,
421421 dim = None ,
422422 tol = None ,
423+ direct = None ,
423424):
425+ """ Check a 2d or 3d set of unit vectors
426+
427+ Check that:
428+ - vectors are defined (if None, they can be inferred)
429+
430+ Normalizes vectorto be unit vectors
431+ Optionally (default) check the vectors forma direct basis
432+ """
433+
434+ # ----------------
435+ # check inputs
436+ # ----------------
424437
425438 # dim
426439 dim = _check_var (dim , 'dim' , types = int , default = 3 , allowed = [2 , 3 ])
427440
428441 # tol
429442 tol = _check_var (tol , 'tol' , types = float , default = 1.e-14 , sign = '>0.' )
430443
444+ # direct
445+ direct = _check_var (direct , 'direct' , types = bool , default = True )
446+
447+ # ----------------------
448+ # check what's provided
449+ # ----------------------
450+
431451 # check is provided
432452 if e0 is not None :
433453 e0 = _check_flat1darray (e0 , 'e0' , size = dim , dtype = float , norm = True )
@@ -436,13 +456,23 @@ def _check_vectbasis(
436456 if e2 is not None :
437457 e2 = _check_flat1darray (e2 , 'e2' , size = dim , dtype = float , norm = True )
438458
459+ # preliminary
460+ allnone = all ([ee is None for ee in [e0 , e1 , e2 ][:dim ]])
461+ if allnone is True :
462+ lstr = [f"\t - { ee } " for ee in ['e0' , 'e1' , 'e2' ][:dim ]]
463+ msg = (
464+ f"For a basis f dimension { dim } , provide at least one of:\n "
465+ + "\n " .join (lstr )
466+ )
467+ raise Exception (msg )
468+
469+ # ----------------------
470+ # dim = 2
471+ # ----------------------
472+
439473 # vectors
440474 if dim == 2 :
441475
442- if e0 is None and e1 is None :
443- msg = "Please provide e0 and/or e1!"
444- raise Exception (msg )
445-
446476 # complete if missing
447477 if e0 is None :
448478 e0 = np .r_ [e1 [1 ], - e1 [0 ]]
@@ -455,24 +485,26 @@ def _check_vectbasis(
455485 raise Exception (msg )
456486
457487 # direct
458- if np .abs (np .cross (e0 , e1 ).tolist () - 1. ) < tol :
459- msg = "Non-direct basis"
460- raise Exception (msg )
488+ if direct is True :
489+ if np .abs (np .cross (e0 , e1 ).tolist () - 1. ) < tol :
490+ msg = "Non-direct basis"
491+ raise Exception (msg )
461492
462493 return e0 , e1
463494
495+ # ----------------------
496+ # dim = 3
497+ # ----------------------
498+
464499 else :
465- if e0 is None and e1 is None and e2 is None :
466- msg = "Please provide at least e0, e1 or e2!"
467- raise Exception (msg )
468500
469501 # complete if 2 missing
470502 if e0 is None and e1 is None :
471503 e1 = _get_horizontal_unitvect (ee = e2 )
472504 elif e0 is None and e2 is None :
473505 e2 = _get_vertical_unitvect (ee = e1 )
474506 elif e1 is None and e2 is None :
475- e2 = _get_vertical_unitvect (ee = e0 )
507+ e2 = _get_horizontal_unitvect (ee = e0 )
476508
477509 # complete if 1 missing
478510 if e0 is None :
@@ -502,9 +534,10 @@ def _check_vectbasis(
502534 raise Exception (msg )
503535
504536 # direct
505- if not np .allclose (np .cross (e0 , e1 ), e2 , atol = tol , rtol = 1e-6 ):
506- msg = "Non-direct basis"
507- raise Exception (msg )
537+ if direct is True :
538+ if not np .allclose (np .cross (e0 , e1 ), e2 , atol = tol , rtol = 1e-6 ):
539+ msg = "Non-direct basis"
540+ raise Exception (msg )
508541
509542 return e0 , e1 , e2
510543
0 commit comments