|
22 | 22 | "full": exir_ops.edge.aten.full.default, |
23 | 23 | "lt": exir_ops.edge.aten.lt.Tensor, |
24 | 24 | "where": exir_ops.edge.aten.where.self, |
25 | | - "mul": exir_ops.edge.aten.mul.Tensor, |
26 | | - "sub": exir_ops.edge.aten.sub.Tensor, |
27 | 25 | } |
28 | 26 |
|
29 | 27 | aten_unary = { |
|
33 | 31 | "full": torch.ops.aten.full.default, |
34 | 32 | "lt": torch.ops.aten.lt.Tensor, |
35 | 33 | "where": torch.ops.aten.where.self, |
36 | | - "mul": torch.ops.aten.mul.Tensor, |
37 | | - "sub": torch.ops.aten.sub.Tensor, |
38 | 34 | } |
39 | 35 |
|
40 | 36 |
|
@@ -74,57 +70,13 @@ def call_operator(self, op, args, kwargs, meta): |
74 | 70 | return q |
75 | 71 |
|
76 | 72 | if rounding_mode == "floor": |
77 | | - q_raw = q |
78 | | - |
79 | | - # trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw)) |
80 | | - q_floor = super().call_operator(opset["floor"], (q_raw,), {}, meta) |
81 | | - q_ceil = super().call_operator(opset["ceil"], (q_raw,), {}, meta) |
82 | | - |
83 | | - # a zero tensor with the right shape |
84 | | - out_shape = (1,) * len(meta["val"].size()) |
85 | | - zero = super().call_operator( |
86 | | - opset["full"], |
87 | | - args=(out_shape, 0.0), |
88 | | - kwargs={}, |
89 | | - meta=meta, |
90 | | - ) |
91 | | - |
92 | | - is_neg = super().call_operator(opset["lt"], (q_raw, zero), {}, meta) |
93 | | - q_trunc = super().call_operator( |
94 | | - opset["where"], (is_neg, q_ceil, q_floor), {}, meta |
95 | | - ) |
96 | | - |
97 | | - # r = a - q_trunc * b (true remainder under truncation) |
98 | | - q_times_b = super().call_operator(opset["mul"], (q_trunc, b), {}, meta) |
99 | | - r = super().call_operator(opset["sub"], (a, q_times_b), {}, meta) |
100 | | - |
101 | | - # Decide if we need to subtract 1: |
102 | | - # for b > 0, adjust if r < 0; for b < 0, adjust if r > 0. |
103 | | - b_pos = super().call_operator(opset["lt"], (zero, b), {}, meta) # b > 0 |
104 | | - r_lt0 = super().call_operator(opset["lt"], (r, zero), {}, meta) # r < 0 |
105 | | - r_gt0 = super().call_operator(opset["lt"], (zero, r), {}, meta) # r > 0 |
106 | | - |
107 | | - adjust_if = super().call_operator( |
108 | | - opset["where"], (b_pos, r_lt0, r_gt0), {}, meta |
109 | | - ) |
110 | | - |
111 | | - one = super().call_operator( |
112 | | - opset["full"], |
113 | | - args=(out_shape, 1.0), |
114 | | - kwargs={}, |
115 | | - meta=meta, |
116 | | - ) |
117 | | - q_minus_1 = super().call_operator(opset["sub"], (q_trunc, one), {}, meta) |
118 | | - |
119 | | - return super().call_operator( |
120 | | - opset["where"], (adjust_if, q_minus_1, q_trunc), {}, meta |
121 | | - ) |
| 73 | + return super().call_operator(opset["floor"], (q,), {}, meta) |
122 | 74 |
|
123 | 75 | if rounding_mode == "trunc": |
124 | 76 | zero = super().call_operator( |
125 | 77 | opset["full"], |
126 | 78 | args=((1,) * len(meta["val"].size()), 0.0), |
127 | | - kwargs={}, |
| 79 | + kwargs={"dtype": torch.float32}, |
128 | 80 | meta=meta, |
129 | 81 | ) |
130 | 82 | lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta) |
|
0 commit comments