|
8 | 8 |
|
9 | 9 | #include "cortex_m_ops_common.h" |
10 | 10 |
|
| 11 | +// Include CMSIS-NN headers with C linkage |
| 12 | +extern "C" { |
| 13 | +#include "arm_nnfunctions.h" |
| 14 | +} |
| 15 | + |
11 | 16 | namespace cortex_m { |
12 | 17 | namespace native { |
13 | 18 | using KernelRuntimeContext = torch::executor::KernelRuntimeContext; |
@@ -54,19 +59,15 @@ Tensor& quantized_add_out( |
54 | 59 | "quantized_add_out: input1_int8.sizes() = %zu", |
55 | 60 | input1_int8.sizes().size()); |
56 | 61 |
|
57 | | - // FIX: Use template types that ExecutorTorch definitely provides |
58 | | - // Use to<int64_t>() and to<double>() which are commonly instantiated |
59 | | - int32_t zp1 = static_cast<int32_t>(input1_zero_point.to<int64_t>()); |
60 | | - int32_t input1_mult = static_cast<int32_t>(input1_multiplier.to<int64_t>()); |
61 | | - int input1_shift_val = static_cast<int>(input1_shift.to<int64_t>()); |
62 | | - |
63 | | - int32_t zp2 = static_cast<int32_t>(input2_zero_point.to<int64_t>()); |
64 | | - int32_t input2_mult = static_cast<int32_t>(input2_multiplier.to<int64_t>()); |
65 | | - int input2_shift_val = static_cast<int>(input2_shift.to<int64_t>()); |
66 | | - |
67 | | - int32_t out_zp = static_cast<int32_t>(output_zero_point.to<int64_t>()); |
68 | | - int32_t output_mult = static_cast<int32_t>(output_multiplier.to<int64_t>()); |
69 | | - int output_shift_val = static_cast<int>(output_shift.to<int64_t>()); |
| 62 | + int32_t zp1 = extractScalarToInt32(input1_zero_point); |
| 63 | + int32_t input1_mult = extractScalarToInt32(input1_multiplier); |
| 64 | + int input1_shift_val = extractScalarToInt(input1_shift); |
| 65 | + int32_t zp2 = extractScalarToInt32(input2_zero_point); |
| 66 | + int32_t input2_mult = extractScalarToInt32(input2_multiplier); |
| 67 | + int input2_shift_val = extractScalarToInt(input2_shift); |
| 68 | + int32_t out_zp = extractScalarToInt32(output_zero_point); |
| 69 | + int32_t output_mult = extractScalarToInt32(output_multiplier); |
| 70 | + int output_shift_val = extractScalarToInt(output_shift); |
70 | 71 |
|
71 | 72 | // Left shift to maximize precision (tune as needed) |
72 | 73 | const int32_t left_shift = 20; |
|
0 commit comments