File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -151,13 +151,18 @@ def result_type(
151
151
return _reduce (_result_type , others + scalars )
152
152
153
153
154
- def _result_type (x , y ):
154
+ def _result_type (
155
+ x : Array | DType | bool | int | float | complex ,
156
+ y : Array | DType | bool | int | float | complex ,
157
+ ) -> DType :
155
158
if not (isinstance (x , _py_scalars ) or isinstance (y , _py_scalars )):
156
- xdt = x . dtype if not isinstance (x , torch .dtype ) else x
157
- ydt = y . dtype if not isinstance (y , torch .dtype ) else y
159
+ xdt = x if isinstance (x , torch .dtype ) else x . dtype
160
+ ydt = y if isinstance (y , torch .dtype ) else y . dtype
158
161
159
- if ( xdt , ydt ) in _promotion_table :
162
+ try :
160
163
return _promotion_table [xdt , ydt ]
164
+ except KeyError :
165
+ pass
161
166
162
167
# This doesn't result_type(dtype, dtype) for non-array API dtypes
163
168
# because torch.result_type only accepts tensors. This does however, allow
You can’t perform that action at this time.
0 commit comments