49
49
"c" : frozenset ({f"{ NP } .complex64" , f"{ NP } .complex128" , f"{ NP } .clongdouble" }),
50
50
}
51
51
52
- DATETIME_OPS : Final = {"+ " , "- " }
53
- TIMEDELTA_OPS : Final = DATETIME_OPS | {"* " , "/ " , "// " , "% " }
54
- BITWISE_OPS : Final = {"<< " , ">> " , "& " , "^ " , "| " }
52
+ DATETIME_OPS : Final = {"add " , "sub " }
53
+ TIMEDELTA_OPS : Final = DATETIME_OPS | {"mul " , "truediv " , "floordiv " , "mod" , "divmod " }
54
+ BITWISE_OPS : Final = {"lshift " , "rshift " , "and " , "or " , "xor " }
55
55
BITWISE_CHARS : Final = "?bhilqBHILQ"
56
56
57
57
INTP_EXPR : Final = f"{ NP } .{ np .intp .__name__ } "
@@ -659,35 +659,57 @@ class ScalarOps(TestGen):
659
659
}
660
660
661
661
ops : Final [dict [str , _BinOp ]]
662
+ names : Final [dict [str , str ]]
663
+ reject : Final [frozenset [str ]]
662
664
663
665
def __init__ (self , kind : _BinOpKind , / ) -> None :
666
+ reject : set [str ] = set ()
667
+ ignore : set [str ] = set ()
668
+
669
+ if kind in {"modular" , "bitwise" }:
670
+ # TODO(jorenham): Reject `inexact`
671
+ reject |= {"bhilBHILefdgFDG" , "FDG" , "F" , "D" , "G" }
672
+
673
+ # ignore the non-standard concrete complex (and float if bitwise) types
674
+ # to avoid generating many redundant rejection tests
675
+ ignore |= set ("FGM" )
676
+ if kind == "bitwise" :
677
+ ignore |= set ("efgm" )
678
+
664
679
match kind :
665
680
case "arithmetic" :
666
681
ops = self .OPS_ARITHMETIC
667
682
case "modular" :
668
683
ops = self .OPS_MODULAR
669
684
case "bitwise" :
670
685
ops = self .OPS_BITWISE
686
+ reject |= {"efdgFDG" , "efdg" , "e" , "f" , "d" , "g" }
671
687
case "comparison" :
672
688
ops = self .OPS_COMPARISON
673
689
674
690
self .ops = ops
691
+ self .names = {k : name for k , name in self .NAMES .items () if k not in ignore }
692
+ self .reject = frozenset (reject )
693
+
675
694
self .testname = self .testname .format (kind )
676
695
677
696
super ().__init__ ()
678
697
679
698
def _is_builtin (self , key : str , / ) -> bool :
680
- return len (key ) > 1 and self .NAMES [key ].endswith ("_py" )
699
+ return len (key ) > 1 and self .names [key ].endswith ("_py" )
681
700
682
701
def _is_abstract (self , key : str , / ) -> bool :
683
- return len (key ) > 1 and not self .NAMES [key ].endswith ("_py" )
702
+ return len (key ) > 1 and not self .names [key ].endswith ("_py" )
684
703
685
704
def _decompose (self , key : str , / ) -> tuple [_Scalar , ...]:
686
705
if not self ._is_abstract (key ):
687
706
return (_scalar (key ),)
688
707
return tuple (map (_scalar , key ))
689
708
690
709
def _evaluate_concrete (self , op : str , lhs : str , rhs : str , / ) -> str | None :
710
+ if lhs in self .reject or rhs in self .reject :
711
+ return None
712
+
691
713
fn = self .ops [op ]
692
714
nout = 2 if fn .__module__ == "builtins" else 1
693
715
@@ -722,26 +744,26 @@ def _evaluate_concrete(self, op: str, lhs: str, rhs: str, /) -> str | None:
722
744
return f"tuple[{ ', ' .join (result_exprs )} ]" if nout > 1 else result_exprs [0 ]
723
745
724
746
def _assert_stmt (self , op : str , lhs : str , rhs : str , / ) -> str | None :
725
- expr_eval = op .format (self .NAMES [lhs ], self .NAMES [rhs ])
726
- is_op = self .ops [op ].__module__ != "builtins" # not the case for divmod
747
+ expr_eval = op .format (self .names [lhs ], self .names [rhs ])
727
748
728
749
if not (expr_type := self ._evaluate_concrete (op , lhs , rhs )):
729
750
# generate rejection test, while avoiding trivial cases
751
+ opname = self .ops [op ].__name__ .removesuffix ("_" )
730
752
if (
731
- # ignore bitwise ops if either arg is not a bitwise char
732
- (op in BITWISE_OPS and {lhs , rhs } - set (BITWISE_CHARS ))
753
+ # ignore bitwise ops if neither arg is a bitwise char
754
+ (opname in BITWISE_OPS and not {lhs , rhs } & set (BITWISE_CHARS ))
733
755
# ignore if either arg is datetime and and not a datetime op
734
- or (op not in DATETIME_OPS and "M" in {lhs , rhs })
756
+ or (opname not in DATETIME_OPS and "M" in {lhs , rhs })
735
757
# ignore if either arg is timedelta and and not a timedelta op
736
- or (op not in TIMEDELTA_OPS and "m" in {lhs , rhs })
758
+ or (opname not in TIMEDELTA_OPS and "m" in {lhs , rhs })
737
759
):
738
760
return None
739
761
740
762
# pyright special casing
741
- if is_op :
742
- pyright_rules = ["OperatorIssue" ]
743
- else :
763
+ if opname == "divmod" :
744
764
pyright_rules = ["ArgumentType" , "CallIssue" ]
765
+ else :
766
+ pyright_rules = ["OperatorIssue" ]
745
767
pyright_ignore = ", " .join (map ("report{}" .format , pyright_rules ))
746
768
747
769
return " " .join ((
@@ -769,7 +791,7 @@ def _assert_stmt(self, op: str, lhs: str, rhs: str, /) -> str | None:
769
791
if (
770
792
op in self .OPS_ARITHMETIC | self .OPS_MODULAR
771
793
and lhs == rhs
772
- and (abstract_arg := self .ABSTRACT_TYPES .get (self .NAMES [lhs ]))
794
+ and (abstract_arg := self .ABSTRACT_TYPES .get (self .names [lhs ]))
773
795
):
774
796
if abstract_arg == "integer" and " / " not in op :
775
797
mypy_ignore = "assert-type, operator"
@@ -790,19 +812,19 @@ def _generate_names_section(self) -> Generator[str]:
790
812
@override
791
813
def get_names (self ) -> Iterable [tuple [str , str ]]:
792
814
# builtin scalars
793
- for builtin , name in self .NAMES .items ():
815
+ for builtin , name in self .names .items ():
794
816
if self ._is_builtin (builtin ):
795
817
yield name , builtin
796
818
797
819
# constrete numpy scalars
798
820
yield "" , ""
799
- for char , name in self .NAMES .items ():
821
+ for char , name in self .names .items ():
800
822
if len (char ) == 1 :
801
823
yield name , _sctype_expr (np .dtype (char ))
802
824
803
825
# abstract numpy scalars
804
826
yield "" , ""
805
- for char , kind in self .NAMES .items ():
827
+ for char , kind in self .names .items ():
806
828
if self ._is_abstract (char ):
807
829
yield kind , f"{ NP } .{ self .ABSTRACT_TYPES [kind ]} "
808
830
@@ -813,13 +835,13 @@ def get_testcases(self) -> Iterable[str | None]:
813
835
814
836
yield from self ._generate_section (f"__[r]{ opname } __" )
815
837
816
- for lhs in self .NAMES :
838
+ for lhs in self .names :
817
839
if self ._is_builtin (lhs ):
818
840
# will cause false positives on pyright; as designed, of course
819
841
continue
820
842
821
843
n = 0
822
- for rhs in self .NAMES :
844
+ for rhs in self .names :
823
845
if fn .__name__ in {"eq" , "ne" } and self ._is_abstract (rhs ):
824
846
# will be inferred by mypy as `Any` for some reason
825
847
continue
0 commit comments