|
50 | 50 | "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" |
51 | 51 | ) |
52 | 52 | lib.define( |
53 | | - "cadence::quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" |
| 53 | + "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" |
| 54 | +) |
| 55 | +lib.define( |
| 56 | + "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " |
| 57 | + "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" |
54 | 58 | ) |
55 | 59 |
|
56 | 60 | lib.define( |
@@ -123,6 +127,28 @@ def quantized_linear_meta( |
123 | 127 | return src.new_empty(out_size, dtype=src.dtype) |
124 | 128 |
|
125 | 129 |
|
@register_fake("cadence::quantized_linear.per_tensor")
def quantized_linear_per_tensor_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: torch.SymInt,
    weight_zero_point: torch.SymInt,
    out_multiplier: torch.SymInt,
    out_shift: torch.SymInt,
    out_zero_point: torch.SymInt,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake (meta) kernel for cadence::quantized_linear.per_tensor.

    Computes only the output shape/dtype, never real values:
    src is [*leading_dims, in_dim], weight is [out_dim, in_dim], and the
    result is an uninitialized tensor of shape [*leading_dims, out_dim]
    with src's dtype. The quantization scalars and optional offset do not
    affect the output metadata.
    """
    # The weight of a linear layer must be a 2-D [out_dim, in_dim] matrix.
    assert weight.dim() == 2
    # Keep src's leading dims; the last dim becomes the weight's out_dim.
    out_shape = list(src.size())
    out_shape[-1] = weight.shape[0]
    return src.new_empty(out_shape, dtype=src.dtype)
| 150 | + |
| 151 | + |
126 | 152 | @register_fake("cadence::quantized_conv") |
127 | 153 | def quantized_conv_meta( |
128 | 154 | input: torch.Tensor, |
|
0 commit comments