44#include < quantized/DeQuantization.h>
55#include < quantized/QTensor.h>
66#include < quantized/Quantization.h>
7+ #include < runtime/Utils.h>
78#include < utils/LRUCache.h>
89
910namespace at {
1011namespace AtenIpexTypeQuantizedXPU {
1112
13+ using namespace xpu ::dpcpp;
14+
15+ template <typename scale_t_, typename zp_t_>
16+ class XPUQuantizerBase {
17+ public:
18+ using scale_t = scale_t_;
19+ using zp_t = zp_t_;
20+ using scale_ptr_t = std::shared_ptr<scale_t >;
21+ using zp_ptr_t = std::shared_ptr<zp_t >;
22+
23+ public:
24+ XPUQuantizerBase () = default ;
25+
26+ XPUQuantizerBase (size_t size, sycl::queue& q)
27+ : scale_ptr_(
28+ sycl::malloc_device<scale_t >(size * sizeof (scale_t ), q),
29+ [=](scale_t * ptr) { sycl::free (ptr, q); }),
30+ zp_ptr_ (
31+ sycl::malloc_device<zp_t >(size * sizeof (zp_t ), q),
32+ [=](zp_t * ptr) { sycl::free (ptr, q); }) {}
33+ scale_t * scale_ptr () {
34+ return scale_ptr_.get ();
35+ }
36+
37+ zp_t * zero_point_ptr () {
38+ return zp_ptr_.get ();
39+ }
40+
41+ private:
42+ scale_ptr_t scale_ptr_;
43+ zp_ptr_t zp_ptr_;
44+ };
45+
1246struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
47+ using QuantizerBaseType = XPUQuantizerBase<float , int32_t >;
48+ using scale_t = QuantizerBaseType::scale_t ;
49+ using zp_t = QuantizerBaseType::zp_t ;
50+
1351 explicit DPCPPPerTensorAffineQuantizer (
1452 ScalarType scalar_type,
1553 double scale,
@@ -22,17 +60,21 @@ struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
2260 }
2361 // TODO: Modify this line after asymmetric enabled
2462 xpu::dpcpp::create_key (key_sc_zp, dnn_scale, 0 );
25- bool key_found = xpu::dpcpp::find_key<std::pair<Tensor, Tensor> >(key_sc_zp);
63+ bool key_found = xpu::dpcpp::find_key<QuantizerBaseType >(key_sc_zp);
2664 if (key_found) {
27- std::tie (scale_tensor_, zero_point_tensor_) =
28- xpu::dpcpp::fetch_m<std::pair<Tensor, Tensor>>(key_sc_zp);
65+ base_ = xpu::dpcpp::fetch_m<QuantizerBaseType>(key_sc_zp);
2966 } else {
30- scale_tensor_ = at::empty ({1 }, at::dtype (kFloat ).device (at::kXPU ))
31- .fill_ (static_cast <float >(dnn_scale));
32- // TODO: Modify this line after asymmetric enabled
33- zero_point_tensor_ = at::zeros ({1 }, at::dtype (kInt ).device (at::kXPU ));
34- xpu::dpcpp::fetch_or_create_m<std::pair<Tensor, Tensor>>(
35- key_sc_zp, scale_tensor_, zero_point_tensor_);
67+ base_ = QuantizerBaseType (1 , dpcppGetCurrentQueue ());
68+
69+ scale_t * sc_ptr = base_.scale_ptr ();
70+ scale_t _scale = (scale_t )dnn_scale;
71+ dpcppGetCurrentQueue ().single_task ([=]() { sc_ptr[0 ] = _scale; });
72+
73+ zp_t * zp_ptr = base_.zero_point_ptr ();
74+ zp_t _zp = (zp_t )0 ;
75+ dpcppGetCurrentQueue ().single_task ([=]() { zp_ptr[0 ] = _zp; });
76+
77+ xpu::dpcpp::fetch_or_create_m<QuantizerBaseType>(key_sc_zp, base_);
3678 }
3779 }
3880
@@ -80,12 +122,12 @@ struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
80122 return zero_point_;
81123 }
82124
83- Tensor scale_tensor () {
84- return scale_tensor_ ;
125+ scale_t * scale_ptr () {
126+ return base_. scale_ptr () ;
85127 }
86128
87- Tensor zero_point_tensor () {
88- return zero_point_tensor_ ;
129+ zp_t * zero_point_ptr () {
130+ return base_. zero_point_ptr () ;
89131 }
90132
91133 bool equalTo (QuantizerPtr other) const override {
@@ -107,8 +149,7 @@ struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
107149 const double scale_;
108150 // We use int64_t for consistency with Python
109151 const int64_t zero_point_;
110- Tensor scale_tensor_;
111- Tensor zero_point_tensor_;
152+ QuantizerBaseType base_;
112153};
113154
114155struct DPCPPPerChannelAffineQuantizer : public AffineQuantizer {
0 commit comments