22 | 22 | "full": exir_ops.edge.aten.full.default, |
23 | 23 | "lt": exir_ops.edge.aten.lt.Tensor, |
24 | 24 | "where": exir_ops.edge.aten.where.self, |
| 25 | + "mul": exir_ops.edge.aten.mul.Tensor, |
| 26 | + "sub": exir_ops.edge.aten.sub.Tensor, |
25 | 27 | } |
26 | 28 | |
27 | 29 | aten_unary = { |
|
31 | 33 | "full": torch.ops.aten.full.default, |
32 | 34 | "lt": torch.ops.aten.lt.Tensor, |
33 | 35 | "where": torch.ops.aten.where.self, |
| 36 | + "mul": torch.ops.aten.mul.Tensor, |
| 37 | + "sub": torch.ops.aten.sub.Tensor, |
34 | 38 | } |
35 | 39 | |
36 | 40 | |
@@ -70,13 +74,57 @@ def call_operator(self, op, args, kwargs, meta): |
70 | 74 | return q |
71 | 75 | |
72 | 76 | if rounding_mode == "floor": |
73 | | - return super().call_operator(opset["floor"], (q,), {}, meta) |
| 77 | + q_raw = q |
| 78 | + |
| 79 | + # trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw)) |
| 80 | + q_floor = super().call_operator(opset["floor"], (q_raw,), {}, meta) |
| 81 | + q_ceil = super().call_operator(opset["ceil"], (q_raw,), {}, meta) |
| 82 | + |
| 83 | + # a zero constant shaped (1,) * rank so it broadcasts against the output |
| 84 | + out_shape = (1,) * len(meta["val"].size()) |
| 85 | + zero = super().call_operator( |
| 86 | + opset["full"], |
| 87 | + args=(out_shape, 0.0), |
| 88 | + kwargs={}, |
| 89 | + meta=meta, |
| 90 | + ) |
| 91 | + |
| 92 | + is_neg = super().call_operator(opset["lt"], (q_raw, zero), {}, meta) |
| 93 | + q_trunc = super().call_operator( |
| 94 | + opset["where"], (is_neg, q_ceil, q_floor), {}, meta |
| 95 | + ) |
| 96 | + |
| 97 | + # r = a - q_trunc * b (the remainder left by the truncating division) |
| 98 | + q_times_b = super().call_operator(opset["mul"], (q_trunc, b), {}, meta) |
| 99 | + r = super().call_operator(opset["sub"], (a, q_times_b), {}, meta) |
| 100 | + |
| 101 | + # Floor differs from trunc only when the remainder is nonzero and has |
| 102 | + # the opposite sign to b: for b > 0, adjust if r < 0; for b < 0, adjust if r > 0. |
| 103 | + b_pos = super().call_operator(opset["lt"], (zero, b), {}, meta) # b > 0 |
| 104 | + r_lt0 = super().call_operator(opset["lt"], (r, zero), {}, meta) # r < 0 |
| 105 | + r_gt0 = super().call_operator(opset["lt"], (zero, r), {}, meta) # r > 0 |
| 106 | + |
| 107 | + adjust_if = super().call_operator( |
| 108 | + opset["where"], (b_pos, r_lt0, r_gt0), {}, meta |
| 109 | + ) |
| 110 | + |
| 111 | + one = super().call_operator( |
| 112 | + opset["full"], |
| 113 | + args=(out_shape, 1.0), |
| 114 | + kwargs={}, |
| 115 | + meta=meta, |
| 116 | + ) |
| 117 | + q_minus_1 = super().call_operator(opset["sub"], (q_trunc, one), {}, meta) |
| 118 | + |
| 119 | + return super().call_operator( |
| 120 | + opset["where"], (adjust_if, q_minus_1, q_trunc), {}, meta |
| 121 | + ) |
74 | 122 | |
75 | 123 | if rounding_mode == "trunc": |
76 | 124 | zero = super().call_operator( |
77 | 125 | opset["full"], |
78 | 126 | args=((1,) * len(meta["val"].size()), 0.0), |
79 | | - kwargs={"dtype": torch.float32}, |
| 127 | + kwargs={}, |
80 | 128 | meta=meta, |
81 | 129 | ) |
82 | 130 | lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta) |
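For reference, the floor branch added above boils down to: truncate the raw quotient, then subtract 1 wherever the remainder is nonzero and has the opposite sign to the divisor. The decomposition emits `mul` and `sub` nodes, which is why both op tables gain those entries, and it materializes `full` constants where the sketch below simply uses Python scalars. This is an illustrative plain-`torch` sketch (the names are not part of the pass) that reproduces `torch.div(..., rounding_mode="floor")`:

```python
# Standalone sketch of the floor decomposition, written with plain torch ops
# rather than the pass infrastructure. All names here are illustrative only.
import torch

def floor_div_via_trunc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    q_raw = a / b  # true division
    # trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw))
    q_trunc = torch.where(q_raw < 0, torch.ceil(q_raw), torch.floor(q_raw))
    # Remainder of the truncating division.
    r = a - q_trunc * b
    # Floor and trunc disagree only when r is nonzero with the opposite sign to b.
    adjust = torch.where(b > 0, r < 0, r > 0)
    return torch.where(adjust, q_trunc - 1, q_trunc)

a = torch.tensor([7.0, -7.0, 7.0, -7.0, 6.0, -6.0])
b = torch.tensor([2.0, 2.0, -2.0, -2.0, 3.0, 3.0])
assert torch.equal(floor_div_via_trunc(a, b),
                   torch.div(a, b, rounding_mode="floor"))
```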