@@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat:
162162
163163
164164@torch_op (("aten::add.Tensor" , "aten::add.Scalar" , "_operator::add" ), trace_only = True )
165- def aten_add (self : TReal , other : TReal , alpha : float = 1.0 ) -> TReal :
165+ def aten_add (self : TTensor , other : TTensor , alpha : float = 1.0 ) -> TTensor :
166166 """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
167- # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
167+
168+ if self .dtype == ir .DataType .BOOL :
169+ # alpha can also be bool
170+ if alpha == 0 :
171+ return op .Identity (self )
172+ return op .Or (self , other )
173+
168174 if alpha != 1.0 :
169175 alpha = op .CastLike (alpha , other )
170176 other = op .Mul (other , alpha )
@@ -1233,15 +1239,19 @@ def aten_binomial(
12331239 "aten::bitwise_and.Tensor" ,
12341240 "aten::bitwise_and.Scalar" ,
12351241 "aten::bitwise_and.Scalar_Tensor" ,
1236- "_operator::and_" ,
12371242 ),
12381243 trace_only = True ,
12391244)
1240- def aten_bitwise_and (self : TInt , other : TInt ) -> TInt :
1245+ def aten_bitwise_and (self : TTensor , other : TTensor ) -> TTensor :
12411246 """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
1242- # logical_and implements the BOOL variant
12431247
1244- return op .BitwiseAnd (self , other )
1248+ assert self .dtype == other .dtype
1249+
1250+ if self .dtype .is_integer ():
1251+ return op .BitwiseAnd (self , other )
1252+ if self .dtype == ir .DataType .BOOL :
1253+ return op .And (self , other )
1254+ raise NotImplementedError (f"Not implemented for types { self .dtype } and { other .dtype } " )
12451255
12461256
12471257@torch_op (
@@ -1329,27 +1339,34 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
13291339
13301340
13311341@torch_op ("aten::bitwise_not" , trace_only = True )
1332- def aten_bitwise_not (self : TInt ) -> TInt :
1342+ def aten_bitwise_not (self : TTensor ) -> TTensor :
13331343 """bitwise_not(Tensor self) -> Tensor"""
1334- # logical_not implements the BOOL variant
13351344
1336- return op .BitwiseNot (self )
1345+ if self .dtype == ir .DataType .BOOL :
1346+ return op .Not (self )
1347+ if self .dtype .is_integer ():
1348+ return op .BitwiseNot (self )
1349+ raise NotImplementedError (f"Not implemented for type { self .dtype } " )
13371350
13381351
13391352@torch_op (
13401353 (
13411354 "aten::bitwise_or.Tensor" ,
13421355 "aten::bitwise_or.Scalar" ,
13431356 "aten::bitwise_or.Scalar_Tensor" ,
1344- "_operator::or_" ,
13451357 ),
13461358 trace_only = True ,
13471359)
1348- def aten_bitwise_or (self : TInt , other : TInt ) -> TInt :
1360+ def aten_bitwise_or (self : TTensor , other : TTensor ) -> TTensor :
13491361 """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
1350- # logical_or implements the BOOL variant
13511362
1352- return op .BitwiseOr (self , other )
1363+ assert self .dtype == other .dtype
1364+
1365+ if self .dtype .is_integer ():
1366+ return op .BitwiseOr (self , other )
1367+ if self .dtype == ir .DataType .BOOL :
1368+ return op .Or (self , other )
1369+ raise NotImplementedError (f"Not implemented for types { self .dtype } and { other .dtype } " )
13531370
13541371
13551372@torch_op (
@@ -1487,11 +1504,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
14871504 ),
14881505 trace_only = True ,
14891506)
1490- def aten_bitwise_xor (self : TInt , other : TInt ) -> TInt :
1507+ def aten_bitwise_xor (self : TTensor , other : TTensor ) -> TTensor :
14911508 """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
1492- # logical_xor implements the BOOL variant
1509+ assert self . dtype == other . dtype
14931510
1494- return op .BitwiseXor (self , other )
1511+ if self .dtype .is_integer ():
1512+ return op .BitwiseXor (self , other )
1513+ if self .dtype == ir .DataType .BOOL :
1514+ return op .Xor (self , other )
1515+ raise NotImplementedError (f"Not implemented for types { self .dtype } and { other .dtype } " )
14951516
14961517
14971518@torch_op ("aten::blackman_window" , trace_only = True )
@@ -5010,58 +5031,46 @@ def aten_logdet(self: TFloat) -> TFloat:
50105031 return op .Log (op .Det (self ))
50115032
50125033
5013- @torch_op (
5014- (
5015- "aten::logical_and" ,
5016- "aten::bitwise_and.Tensor" ,
5017- "aten::bitwise_and.Scalar" ,
5018- "aten::bitwise_and.Scalar_Tensor" ,
5019- ),
5020- trace_only = True ,
5021- )
5022- def aten_logical_and (self : BOOL , other : BOOL ) -> BOOL :
5034+ @torch_op ("aten::logical_and" , trace_only = True )
5035+ def aten_logical_and (self : TTensor , other : TTensor ) -> BOOL :
50235036 """logical_and(Tensor self, Tensor other) -> Tensor"""
50245037
5025- return op .And (self , other )
5038+ assert self .dtype == other .dtype
5039+
5040+ if self .dtype == ir .DataType .BOOL :
5041+ return op .And (self , other )
5042+ return op .And (op .Cast (self , to = BOOL .dtype ), op .Cast (other , to = BOOL .dtype ))
50265043
50275044
5028- @torch_op (( "aten::logical_not" , "aten::bitwise_not" ) , trace_only = True )
5029- def aten_logical_not (self : BOOL ) -> BOOL :
5045+ @torch_op ("aten::logical_not" , trace_only = True )
5046+ def aten_logical_not (self : TTensor ) -> BOOL :
50305047 """logical_not(Tensor self) -> Tensor"""
50315048
5032- return op .Not (self )
5049+ if self .dtype == ir .DataType .BOOL :
5050+ return op .Not (self )
5051+ return op .Not (op .Cast (self , to = BOOL .dtype ))
50335052
50345053
5035- @torch_op (
5036- (
5037- "aten::logical_or" ,
5038- "aten::bitwise_or.Tensor" ,
5039- "aten::bitwise_or.Scalar" ,
5040- "aten::bitwise_or.Scalar_Tensor" ,
5041- "aten::add.Tensor" ,
5042- "aten::add.Scalar" ,
5043- ),
5044- trace_only = True ,
5045- )
5046- def aten_logical_or (self : BOOL , other : BOOL ) -> BOOL :
5054+ @torch_op (("aten::logical_or" ), trace_only = True )
5055+ def aten_logical_or (self : TTensor , other : TTensor ) -> BOOL :
50475056 """logical_or(Tensor self, Tensor other) -> Tensor"""
50485057
5049- return op . Or ( self , other )
5058+ assert self . dtype == other . dtype
50505059
5060+ if self .dtype == ir .DataType .BOOL :
5061+ return op .Or (self , other )
5062+ return op .Or (op .Cast (self , to = BOOL .dtype ), op .Cast (other , to = BOOL .dtype ))
50515063
5052- @torch_op (
5053- (
5054- "aten::logical_xor" ,
5055- "aten::bitwise_xor.Tensor" ,
5056- "aten::bitwise_xor.Scalar" ,
5057- "aten::bitwise_xor.Scalar_Tensor" ,
5058- ),
5059- trace_only = True ,
5060- )
5061- def aten_logical_xor (self : BOOL , other : BOOL ) -> BOOL :
5064+
5065+ @torch_op ("aten::logical_xor" , trace_only = True )
5066+ def aten_logical_xor (self : TTensor , other : TTensor ) -> BOOL :
50625067 """logical_xor(Tensor self, Tensor other) -> Tensor"""
50635068
5064- return op .Xor (self , other )
5069+ assert self .dtype == other .dtype
5070+
5071+ if self .dtype == ir .DataType .BOOL :
5072+ return op .Xor (self , other )
5073+ return op .Xor (op .Cast (self , to = BOOL .dtype ), op .Cast (other , to = BOOL .dtype ))
50655074
50665075
50675076@torch_op ("aten::logit" , private = True )
0 commit comments