@@ -34,9 +34,15 @@ def get_floor_divide_decomposition(op) -> tuple:
3434 """
3535
3636 if op in edge_floor_divide_ops :
37- return (exir_ops .edge .aten .div .Tensor_mode ,)
37+ return (
38+ exir_ops .edge .aten .div .Tensor_mode ,
39+ exir_ops .edge .aten .full_like .default ,
40+ )
3841 if op in aten_floor_divide_ops :
39- return (torch .ops .aten .div .Tensor_mode ,)
42+ return (
43+ torch .ops .aten .div .Tensor_mode ,
44+ torch .ops .aten .full_like .default ,
45+ )
4046
4147 raise RuntimeError (f"Can't get floor_div decomposition for op { op } " )
4248
@@ -52,11 +58,16 @@ def call_operator(self, op, args, kwargs, meta):
5258 if op not in (edge_floor_divide_ops + aten_floor_divide_ops ):
5359 return super ().call_operator (op , args , kwargs , meta , updated = False )
5460
55- (div_op ,) = get_floor_divide_decomposition (op )
61+ (div_op , full_op ) = get_floor_divide_decomposition (op )
5662
5763 input = args [0 ]
5864 other = args [1 ]
5965
66+ if isinstance (other , int ):
67+ other = super ().call_operator (
68+ full_op , (input , other ), {}, meta , updated = False
69+ )
70+
6071 div_node = super ().call_operator (
6172 div_op , (input , other ), {"rounding_mode" : "floor" }, meta , updated = True
6273 )
0 commit comments