@@ -28,24 +28,27 @@ class TensorConstant:
 class TensorOpInfo:
     target: torch._ops.OpOverload
     use_schema_args: bool
+    use_self_dtype: bool
 
 
 SCALAR_OPS = {
-    aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False),
-    aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False),
-    aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False),
-    aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False),
-    aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False),
-    aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False),
-    aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False),
-    aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False),
-    aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False),
-    aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False),
-    aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False),
-    aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False),
-    aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False),
+    aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False, False),
+    aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False, False),
+    aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False, False),
+    aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False, False),
+    aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False, False),
+    aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False, False),
+    aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False, False),
+    aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False, False),
+    aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False, False),
+    aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False, False),
+    aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False, False),
+    aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False),
+    aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False),
     # The scalar number arg[1] is missing when the default is used, resulting in a corner case to handle
-    aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True),
+    aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False),
+    aten.where.ScalarOther: TensorOpInfo(aten.where.self, False, True),
+    aten.where.Scalar: TensorOpInfo(aten.where.self, False, True),
 }
 
 
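The new use_self_dtype flag marks scalar overloads (here the two aten.where entries) whose args[0] is a condition rather than a data tensor, so the materialized scalar constant should take the output dtype from node.meta["val"]. A minimal sketch of the intended dtype selection, using purely illustrative names and assuming the is_float_tensor helper from this module (the actual logic lives in _build_tensor_constant in the next hunk):

def pick_scalar_dtype(node: fx.Node, info: TensorOpInfo) -> torch.dtype:
    # Sketch only: where-style ops (use_self_dtype=True) carry a bool condition
    # in args[0], so the scalar constant should use the op's output dtype.
    if info.use_self_dtype or is_float_tensor(node):
        return node.meta["val"].dtype
    # Otherwise reuse the dtype of the first tensor argument.
    return node.args[0].meta["val"].dtype
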
@@ -63,11 +66,14 @@ def __init__(self):
     def _build_tensor_constant(
         self, gm: torch.fx.GraphModule, node: fx.Node, const_val
     ) -> TensorConstant:
+        # In some cases we cannot use node.args[0] to pick the scalar's dtype.
+        # Ex: for the where op, args[0] is a bool condition, while args[1]/args[2] should match node.meta["val"]'s dtype rather than bool.
         tensor = torch.tensor(
             [const_val],
             dtype=(
                 node.args[0].meta["val"].dtype
                 if not is_float_tensor(node)
+                and not SCALAR_OPS.get(node.target).use_self_dtype
                 else node.meta["val"].dtype
             ),
             device=node.meta["val"].device,
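
For reference, an assumed reproduction of the case this flag targets is exporting a model that calls torch.where with a Python scalar. Depending on the PyTorch version, the call may lower to aten.where.ScalarOther (or aten.where.Scalar); with the table entries added above, the pass rewrites it to aten.where.self and builds the scalar constant with the dtype from node.meta["val"] rather than the bool dtype of the condition in args[0]:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # Scalar "other" operand; the condition tensor (x > 0) is bool.
        return torch.where(x > 0, x, 0.0)

ep = torch.export.export(M(), (torch.randn(4),))
print(ep.graph)  # inspect which where overload the graph contains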