@@ -28,24 +28,27 @@ class TensorConstant:
 class TensorOpInfo:
     target: torch._ops.OpOverload
     use_schema_args: bool
+    use_self_dtype: bool
 
 
 SCALAR_OPS = {
-    aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False),
-    aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False),
-    aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False),
-    aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False),
-    aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False),
-    aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False),
-    aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False),
-    aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False),
-    aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False),
-    aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False),
-    aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False),
-    aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False),
-    aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False),
+    aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False, False),
+    aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False, False),
+    aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False, False),
+    aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False, False),
+    aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False, False),
+    aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False, False),
+    aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False, False),
+    aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False, False),
+    aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False, False),
+    aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False, False),
+    aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False, False),
+    aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),
+    aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
     # The scalar number arg[1] is missing when using default. Result in a corner case to deal
-    aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True),
+    aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
+    aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
+    aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
 }
 
 
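The new third positional argument fills the use_self_dtype field; only the two where entries set it to True. Below is a minimal sketch, not part of the diff, of the rewrite these entries are meant to drive, assuming the pass replaces the Scalar overload with aten.where.self and a materialized constant tensor (the variable names here are illustrative):

import torch

cond = torch.tensor([True, False])   # args[0] of aten.where.ScalarOther: a bool mask
x = torch.tensor([1.5, 2.5])         # args[1]: a tensor operand
scalar = 0.0                         # args[2]: the Scalar being replaced

# With use_self_dtype=True the constant takes the op output's dtype
# (node.meta["val"].dtype), not args[0]'s dtype, which for `where` is bool.
const = torch.tensor([scalar], dtype=x.dtype, device=x.device)
out = torch.ops.aten.where.self(cond, x, const)
print(out)  # tensor([1.5000, 0.0000])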
@@ -63,11 +66,14 @@ def __init__(self):
     def _build_tensor_constant(
         self, gm: torch.fx.GraphModule, node: fx.Node, const_val
     ) -> TensorConstant:
+        # In some cases we cannot use node.args[0]'s dtype for the scalar constant.
+        # Ex: for the where op, args[0] can be bool, but args[1] and args[2] should
+        # match node.meta["val"].dtype rather than bool.
         tensor = torch.tensor(
             [const_val],
             dtype=(
                 node.args[0].meta["val"].dtype
                 if not is_float_tensor(node)
+                and not SCALAR_OPS.get(node.target).use_self_dtype
                 else node.meta["val"].dtype
             ),
             device=node.meta["val"].device,
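For reference, the dtype selection above can be read as the following standalone sketch; pick_constant_dtype is a hypothetical helper introduced only for illustration, while SCALAR_OPS and is_float_tensor are taken from the surrounding file and passed in here to keep the sketch self-contained:

def pick_constant_dtype(node, scalar_ops, is_float_tensor):
    # Restatement of the dtype expression in _build_tensor_constant.
    # Like the original expression, it assumes node.target is a key of scalar_ops.
    info = scalar_ops.get(node.target)
    if is_float_tensor(node) or info.use_self_dtype:
        # Use the op output's dtype, e.g. for aten.where.* where args[0] is a bool mask.
        return node.meta["val"].dtype
    # Otherwise inherit the dtype of the first tensor argument.
    return node.args[0].meta["val"].dtype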