|
7 | 7 | #include "caffe/util/math_functions.hpp" |
8 | 8 | #include "caffe/util/rng.hpp" |
9 | 9 |
|
// Saturation limits used by the caffe_cpu_*_saturate helpers in this file.
// 2047 / -2048 are the extremes of a 12-bit two's-complement integer and
// 4095 is the largest unsigned 12-bit value; 127 / -128 / 255 are the
// corresponding 8-bit limits.
// Negative replacement lists are parenthesized so the leading minus sign can
// never combine with a neighbouring token at the point of expansion.
#define SIGNED_SATURATE_MAX 2047
#define SIGNED_SATURATE_MIN (-2048)
#define UNSIGNED_SATURATE_MAX 4095
#define SIGNED_8BIT_SATURATE_MAX 127
#define SIGNED_8BIT_SATURATE_MIN (-128)
#define UNSIGNED_8BIT_SATURATE_MAX 255
| 16 | + |
10 | 17 | namespace caffe { |
11 | 18 |
|
12 | 19 | template<> |
@@ -397,4 +404,95 @@ void caffe_cpu_scale<double>(const int n, const double alpha, const double *x, |
397 | 404 | cblas_dscal(n, alpha, y, 1); |
398 | 405 | } |
399 | 406 |
|
// Clamps the first n elements of x, in place, against the two given limits.
// Each element is first capped at SATURATE_MAX and the (possibly updated)
// value is then raised to SATURATE_MIN if it is still below it. The two
// checks run in that order for every element.
template <typename Dtype>
void caffe_cpu_universal_saturate(const int n, Dtype* x, Dtype SATURATE_MAX, Dtype SATURATE_MIN) {
  int idx = 0;
  while (idx < n) {
    Dtype& value = x[idx];
    // Upper limit first.
    if (value > SATURATE_MAX) {
      value = SATURATE_MAX;
    }
    // Then the lower limit, evaluated on the value left by the step above.
    if (value < SATURATE_MIN) {
      value = SATURATE_MIN;
    }
    ++idx;
  }
}
| 418 | + |
| 419 | +template <> |
| 420 | +void caffe_cpu_signed_saturate<float>(const int n, float* x) { |
| 421 | + caffe_cpu_universal_saturate<float>(n, x, SIGNED_SATURATE_MAX, SIGNED_SATURATE_MIN); |
| 422 | +} |
| 423 | + |
| 424 | +template <> |
| 425 | +void caffe_cpu_signed_saturate<double>(const int n, double* x) { |
| 426 | + caffe_cpu_universal_saturate<double>(n, x, SIGNED_SATURATE_MAX, SIGNED_SATURATE_MIN); |
| 427 | +} |
| 428 | + |
| 429 | +template <> |
| 430 | +void caffe_cpu_unsigned_saturate<float>(const int n, float* x) { |
| 431 | + caffe_cpu_universal_saturate<float>(n, x, UNSIGNED_SATURATE_MAX, 0); |
| 432 | +} |
| 433 | + |
| 434 | +template <> |
| 435 | +void caffe_cpu_unsigned_saturate<double>(const int n, double* x) { |
| 436 | + caffe_cpu_universal_saturate<double>(n, x, UNSIGNED_SATURATE_MAX, 0); |
| 437 | +} |
| 438 | + |
| 439 | +template <> |
| 440 | +void caffe_cpu_signed_8bit_saturate<float>(const int n, float* x) { |
| 441 | + caffe_cpu_universal_saturate<float>(n, x, SIGNED_8BIT_SATURATE_MAX, SIGNED_8BIT_SATURATE_MIN); |
| 442 | +} |
| 443 | + |
| 444 | +template <> |
| 445 | +void caffe_cpu_signed_8bit_saturate<double>(const int n, double* x) { |
| 446 | + caffe_cpu_universal_saturate<double>(n, x, SIGNED_8BIT_SATURATE_MAX, SIGNED_8BIT_SATURATE_MIN); |
| 447 | +} |
| 448 | + |
| 449 | +template <> |
| 450 | +void caffe_cpu_unsigned_8bit_saturate<float>(const int n, float* x) { |
| 451 | + caffe_cpu_universal_saturate<float>(n, x, UNSIGNED_8BIT_SATURATE_MAX, 0); |
| 452 | +} |
| 453 | +template <> |
| 454 | +void caffe_cpu_unsigned_8bit_saturate<double>(const int n, double* x) { |
| 455 | + caffe_cpu_universal_saturate<double>(n, x, UNSIGNED_8BIT_SATURATE_MAX, 0); |
| 456 | +} |
| 457 | + |
// Replaces each of the first n elements of x with std::rint of that element.
// std::rint rounds to an integral value using the current floating-point
// rounding mode, so the tie-breaking rule is whatever mode is active.
template <typename Dtype>
void caffe_cpu_round(const int n, Dtype *x) {
  Dtype* cursor = x;
  for (int remaining = n; remaining > 0; --remaining, ++cursor) {
    *cursor = std::rint(*cursor);
  }
}

// Explicit instantiations for the two supported element types.
template void caffe_cpu_round<float>(const int n, float* x);

template void caffe_cpu_round<double>(const int n, double* x);
| 468 | + |
| 469 | +template <typename Dtype> |
| 470 | +void caffe_cpu_quantize(const int n, Dtype* x, const Dtype scale, const int zero_point){ |
| 471 | + if (scale != Dtype(1.0)) { |
| 472 | + caffe_div_scalar<Dtype>(n, scale, x); |
| 473 | + caffe_cpu_round<Dtype>(n, x); |
| 474 | + } |
| 475 | + if (zero_point != 0) { |
| 476 | + caffe_add_scalar<Dtype>(n, Dtype(zero_point), x); |
| 477 | + } |
| 478 | +} |
| 479 | + |
| 480 | +template void caffe_cpu_quantize<float>(const int n, float* x, const float scale, const int zero_point); |
| 481 | + |
| 482 | +template void caffe_cpu_quantize<double>(const int n, double* x, const double scale, const int zero_point); |
| 483 | + |
| 484 | +template <typename Dtype> |
| 485 | +void caffe_cpu_dequantize(const int n, Dtype* x, const Dtype scale, const int zero_point){ |
| 486 | + if (zero_point != 0) { |
| 487 | + caffe_add_scalar<Dtype>(n, Dtype(-zero_point), x); |
| 488 | + } |
| 489 | + if (scale != Dtype(1.0)) { |
| 490 | + caffe_scal<Dtype>(n, scale, x); |
| 491 | + } |
| 492 | +} |
| 493 | + |
| 494 | +template void caffe_cpu_dequantize<float>(const int n, float* x, const float scale, const int zero_point); |
| 495 | + |
| 496 | +template void caffe_cpu_dequantize<double>(const int n, double* x, const double scale, const int zero_point); |
| 497 | + |
400 | 498 | } // namespace caffe |
0 commit comments