Skip to content

Commit 9a40718

Browse files
zhuyuhua-v, ZhiweiYan-96, fengyuan14, xiaolil1
authored
Rc4/separate memory management for quantizer (#2580)
* Refactor q reorder design (#2471) * Refactor q reorder design * Initialize args mark in constructor * Remove exclusive relation on sc&zp setting between src&dst * Change quantized_reorder api * Use cached sc&zp in quant/dequant --------- Co-authored-by: Zhiwei <[email protected]> Co-authored-by: Feng Yuan <[email protected]> Co-authored-by: xiaolil1 <[email protected]>
1 parent f7e76a6 commit 9a40718

File tree

10 files changed

+231
-107
lines changed

10 files changed

+231
-107
lines changed

csrc/gpu/aten/operators/ReQuantization.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,21 @@ Tensor requantize(
2525
auto reorder_attr = xpu::oneDNN::ReorderAttr();
2626
int mask = 0;
2727
auto scale_in = src.is_quantized() ? static_cast<float>(src.q_scale()) : 1.f;
28-
auto requant_scale = static_cast<float>(1.f / (scale_out / scale_in));
28+
auto requant_scale = static_cast<float>((scale_out / scale_in));
2929

30-
Tensor dnn_scale =
31-
at::ones(1, at::dtype(at::kFloat).device(at::kXPU)) * requant_scale;
30+
Tensor dnn_scale = at::empty({1}, at::dtype(at::kFloat).device(at::kXPU))
31+
.fill_(requant_scale);
3232
// TODO: Remove workaround for dnnl symmetric quantization
33-
Tensor dnn_zero_point =
34-
at::ones(1, at::dtype(at::kInt).device(at::kXPU)) * zero_point_out;
33+
Tensor dnn_zero_point = at::zeros({1}, at::dtype(at::kInt).device(at::kXPU));
3534
reorder_attr.set_dst_sc_and_zp_mask(mask);
3635
xpu::oneDNN::quantized_reorder(
37-
src, dst_, dnn_scale, dnn_zero_point, reorder_attr);
36+
src,
37+
dst_,
38+
/*src_scale=*/Tensor(),
39+
/*src_zero_point=*/Tensor(),
40+
dnn_scale,
41+
dnn_zero_point,
42+
reorder_attr);
3843

3944
return dst_;
4045
}

csrc/gpu/aten/quantized/DeQuantization.cpp

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using namespace dnnl;
1414
using namespace at::native;
1515
using namespace xpu::dpcpp;
1616
using namespace xpu::oneDNN;
17+
using namespace at::AtenIpexTypeQuantizedXPU;
1718

1819
namespace at {
1920
namespace AtenIpexTypeXPU {
@@ -25,25 +26,38 @@ Tensor dequantize_tensor_per_tensor_affine(
2526
int64_t zero_point) {
2627
ReorderAttr rattr = ReorderAttr();
2728
int mask = 0;
28-
auto q_ctx = DPCPPTensorContext::get_tensor_ctx(qtensor);
29-
// TODO: Remove workaround for dnnl symmetric quantization
30-
float true_scale = ((q_ctx.is_plain() ? get_onednn_dtype(qtensor)
31-
: q_ctx.meta().get_data_type()) ==
32-
memory::data_type::u8 &&
33-
qtensor.q_zero_point() == 128)
34-
? static_cast<float>(scale / 2)
35-
: static_cast<float>(scale);
3629
rattr.set_src_sc_and_zp_mask(mask);
3730

38-
// See [Note: Scale setting for reorder]
39-
Tensor dnn_scale =
40-
at::ones(1, at::dtype(at::kFloat).device(at::kXPU)) * true_scale;
41-
// TODO: Remove workaround for dnnl symmetric quantization
42-
Tensor dnn_zero_point = at::zeros(1, at::dtype(at::kInt).device(at::kXPU));
43-
4431
Tensor rtensor_ = at::empty(qtensor.sizes(), rtensor.options());
45-
xpu::oneDNN::quantized_reorder(
46-
qtensor, rtensor_, dnn_scale, dnn_zero_point, rattr);
32+
if (is_opaque_u8(qtensor)) {
33+
Tensor dnn_scale =
34+
at::empty({1}, at::dtype(at::kFloat).device(at::kXPU)).fill_(scale);
35+
Tensor dnn_zero_point =
36+
at::zeros({1}, at::dtype(at::kInt).device(at::kXPU));
37+
38+
// See [Note: Scale setting for reorder]
39+
xpu::oneDNN::quantized_reorder(
40+
qtensor,
41+
rtensor_,
42+
dnn_scale,
43+
dnn_zero_point,
44+
/*dst_scale=*/Tensor(),
45+
/*dst_zero_point=*/Tensor(),
46+
rattr);
47+
} else {
48+
// See [Note: Scale setting for reorder]
49+
xpu::oneDNN::quantized_reorder(
50+
qtensor,
51+
rtensor_,
52+
q_scale_ptr(qtensor),
53+
q_zero_point_ptr(qtensor),
54+
/*dst_scale=*/nullptr,
55+
/*dst_zero_point=*/nullptr,
56+
{1},
57+
{1},
58+
rattr);
59+
}
60+
4761
return rtensor_;
4862
}
4963

@@ -91,7 +105,13 @@ Tensor dequantize_tensor_per_channel_affine(
91105

92106
Tensor rtensor_ = empty_opaque_tensor(r_md, rtensor.options(), c10::nullopt);
93107
xpu::oneDNN::quantized_reorder(
94-
qtensor, rtensor_, dnn_scale, dnn_zero_point, rattr);
108+
qtensor,
109+
rtensor_,
110+
dnn_scale,
111+
dnn_zero_point,
112+
/*dst_scale=*/Tensor(),
113+
/*dst_zero_point=*/Tensor(),
114+
rattr);
95115

96116
return rtensor_;
97117
}

csrc/gpu/aten/quantized/QTensor.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,19 @@ int64_t q_per_channel_axis(const Tensor& self) {
122122
return at::native::q_per_channel_axis(self);
123123
}
124124

125-
Tensor q_scale_tensor(const Tensor& self) {
125+
float* q_scale_ptr(const Tensor& self) {
126126
auto quantizer = get_qtensorimpl(self)->quantizer();
127127
TORCH_CHECK(quantizer->qscheme() == kPerTensorAffine);
128128
return static_cast<DPCPPPerTensorAffineQuantizer*>(quantizer.get())
129-
->scale_tensor();
129+
->scale_ptr();
130130
}
131131

132-
Tensor q_zero_point_tensor(const Tensor& self) {
132+
int32_t* q_zero_point_ptr(const Tensor& self) {
133133
auto quantizer = get_qtensorimpl(self)->quantizer();
134134
TORCH_CHECK(quantizer->qscheme() == kPerTensorAffine);
135135
return static_cast<DPCPPPerTensorAffineQuantizer*>(quantizer.get())
136-
->zero_point_tensor();
136+
->zero_point_ptr();
137137
}
138-
139138
Tensor& set_(
140139
Tensor& self,
141140
Storage storage,

csrc/gpu/aten/quantized/QTensor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ Tensor q_per_channel_zero_points(const Tensor& self);
3636

3737
int64_t q_per_channel_axis(const Tensor& self);
3838

39-
Tensor q_scale_tensor(const Tensor& self);
39+
float* q_scale_ptr(const Tensor& self);
4040

41-
Tensor q_zero_point_tensor(const Tensor& self);
41+
int32_t* q_zero_point_ptr(const Tensor& self);
4242

4343
Tensor& set_(
4444
Tensor& self,

csrc/gpu/aten/quantized/Quantization.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,13 @@ Tensor quantize_tensor_per_channel_affine(
7777
Tensor dnn_zero_point =
7878
at::zeros_like(zero_points, dtype(at::kInt).device(at::kXPU));
7979
xpu::oneDNN::quantized_reorder(
80-
rtensor, qtensor, dnn_scale, dnn_zero_point, rattr);
80+
rtensor,
81+
qtensor,
82+
/*src_scale=*/Tensor(),
83+
/*src_zero_point=*/Tensor(),
84+
dnn_scale,
85+
dnn_zero_point,
86+
rattr);
8187

8288
return qtensor;
8389
}
@@ -134,15 +140,27 @@ Tensor quantize_tensor_per_tensor_affine(
134140
AtenIpexTypeXPU::empty_opaque_qtensor(q_md, c10::nullopt, quantizer);
135141

136142
xpu::oneDNN::quantized_reorder(
137-
rtensor, qtensor_opt, dnn_scale, dnn_zero_point, rattr);
143+
rtensor,
144+
qtensor_opt,
145+
/*src_scale=*/Tensor(),
146+
/*src_zero_point=*/Tensor(),
147+
dnn_scale,
148+
dnn_zero_point,
149+
rattr);
138150
auto q_opt_ctx =
139151
at::AtenIpexTypeXPU::DPCPPTensorContext::release_tensor_ctx(
140152
qtensor_opt);
141153
at::AtenIpexTypeXPU::DPCPPTensorContext::set_tensor_ctx(
142154
qtensor, std::move(q_opt_ctx));
143155
} else {
144156
xpu::oneDNN::quantized_reorder(
145-
rtensor, qtensor, dnn_scale, dnn_zero_point, rattr);
157+
rtensor,
158+
qtensor,
159+
/*src_scale=*/Tensor(),
160+
/*src_zero_point=*/Tensor(),
161+
dnn_scale,
162+
dnn_zero_point,
163+
rattr);
146164
}
147165

148166
return qtensor;

csrc/gpu/aten/quantized/Quantizer.h

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,50 @@
44
#include <quantized/DeQuantization.h>
55
#include <quantized/QTensor.h>
66
#include <quantized/Quantization.h>
7+
#include <runtime/Utils.h>
78
#include <utils/LRUCache.h>
89

910
namespace at {
1011
namespace AtenIpexTypeQuantizedXPU {
1112

13+
using namespace xpu::dpcpp;
14+
15+
template <typename scale_t_, typename zp_t_>
16+
class XPUQuantizerBase {
17+
public:
18+
using scale_t = scale_t_;
19+
using zp_t = zp_t_;
20+
using scale_ptr_t = std::shared_ptr<scale_t>;
21+
using zp_ptr_t = std::shared_ptr<zp_t>;
22+
23+
public:
24+
XPUQuantizerBase() = default;
25+
26+
XPUQuantizerBase(size_t size, sycl::queue& q)
27+
: scale_ptr_(
28+
sycl::malloc_device<scale_t>(size * sizeof(scale_t), q),
29+
[=](scale_t* ptr) { sycl::free(ptr, q); }),
30+
zp_ptr_(
31+
sycl::malloc_device<zp_t>(size * sizeof(zp_t), q),
32+
[=](zp_t* ptr) { sycl::free(ptr, q); }) {}
33+
scale_t* scale_ptr() {
34+
return scale_ptr_.get();
35+
}
36+
37+
zp_t* zero_point_ptr() {
38+
return zp_ptr_.get();
39+
}
40+
41+
private:
42+
scale_ptr_t scale_ptr_;
43+
zp_ptr_t zp_ptr_;
44+
};
45+
1246
struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
47+
using QuantizerBaseType = XPUQuantizerBase<float, int32_t>;
48+
using scale_t = QuantizerBaseType::scale_t;
49+
using zp_t = QuantizerBaseType::zp_t;
50+
1351
explicit DPCPPPerTensorAffineQuantizer(
1452
ScalarType scalar_type,
1553
double scale,
@@ -22,17 +60,21 @@ struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
2260
}
2361
// TODO: Modify this line after asymmetric enabled
2462
xpu::dpcpp::create_key(key_sc_zp, dnn_scale, 0);
25-
bool key_found = xpu::dpcpp::find_key<std::pair<Tensor, Tensor>>(key_sc_zp);
63+
bool key_found = xpu::dpcpp::find_key<QuantizerBaseType>(key_sc_zp);
2664
if (key_found) {
27-
std::tie(scale_tensor_, zero_point_tensor_) =
28-
xpu::dpcpp::fetch_m<std::pair<Tensor, Tensor>>(key_sc_zp);
65+
base_ = xpu::dpcpp::fetch_m<QuantizerBaseType>(key_sc_zp);
2966
} else {
30-
scale_tensor_ = at::empty({1}, at::dtype(kFloat).device(at::kXPU))
31-
.fill_(static_cast<float>(dnn_scale));
32-
// TODO: Modify this line after asymmetric enabled
33-
zero_point_tensor_ = at::zeros({1}, at::dtype(kInt).device(at::kXPU));
34-
xpu::dpcpp::fetch_or_create_m<std::pair<Tensor, Tensor>>(
35-
key_sc_zp, scale_tensor_, zero_point_tensor_);
67+
base_ = QuantizerBaseType(1, dpcppGetCurrentQueue());
68+
69+
scale_t* sc_ptr = base_.scale_ptr();
70+
scale_t _scale = (scale_t)dnn_scale;
71+
dpcppGetCurrentQueue().single_task([=]() { sc_ptr[0] = _scale; });
72+
73+
zp_t* zp_ptr = base_.zero_point_ptr();
74+
zp_t _zp = (zp_t)0;
75+
dpcppGetCurrentQueue().single_task([=]() { zp_ptr[0] = _zp; });
76+
77+
xpu::dpcpp::fetch_or_create_m<QuantizerBaseType>(key_sc_zp, base_);
3678
}
3779
}
3880

@@ -80,12 +122,12 @@ struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
80122
return zero_point_;
81123
}
82124

83-
Tensor scale_tensor() {
84-
return scale_tensor_;
125+
scale_t* scale_ptr() {
126+
return base_.scale_ptr();
85127
}
86128

87-
Tensor zero_point_tensor() {
88-
return zero_point_tensor_;
129+
zp_t* zero_point_ptr() {
130+
return base_.zero_point_ptr();
89131
}
90132

91133
bool equalTo(QuantizerPtr other) const override {
@@ -107,8 +149,7 @@ struct DPCPPPerTensorAffineQuantizer : public AffineQuantizer {
107149
const double scale_;
108150
// We use int64_t for consistency with Python
109151
const int64_t zero_point_;
110-
Tensor scale_tensor_;
111-
Tensor zero_point_tensor_;
152+
QuantizerBaseType base_;
112153
};
113154

114155
struct DPCPPPerChannelAffineQuantizer : public AffineQuantizer {

csrc/gpu/oneDNN/Matmul.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,7 @@ static inline void matmul(
363363
.fill_(m1.q_scale());
364364
m1_sc_m = dpcpp_onednn_memory(m1_sc_md, engine, m1_sc.data_ptr());
365365
} else {
366-
m1_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(m1);
367-
m1_sc_m = dpcpp_onednn_memory(m1_sc_md, engine, m1_sc.data_ptr());
366+
m1_sc_m = dpcpp_onednn_memory(m1_sc_md, engine, q_scale_ptr(m1));
368367
}
369368
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc_m});
370369

@@ -373,8 +372,7 @@ static inline void matmul(
373372
if (m2.is_quantized()) {
374373
memory::desc m2_sc_md =
375374
memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
376-
m2_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(m2);
377-
m2_sc_m = dpcpp_onednn_memory(m2_sc_md, engine, m2_sc.data_ptr());
375+
m2_sc_m = dpcpp_onednn_memory(m2_sc_md, engine, q_scale_ptr(m2));
378376
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, m2_sc_m});
379377
}
380378

@@ -383,8 +381,7 @@ static inline void matmul(
383381
if (dst.is_quantized()) {
384382
memory::desc dst_sc_md =
385383
memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
386-
dst_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(dst);
387-
dst_sc_m = dpcpp_onednn_memory(dst_sc_md, engine, dst_sc.data_ptr());
384+
dst_sc_m = dpcpp_onednn_memory(dst_sc_md, engine, q_scale_ptr(dst));
388385
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
389386
}
390387

@@ -421,9 +418,7 @@ static inline void matmul(
421418
if (is_per_tensor_quantized) {
422419
memory::desc wgh_sc_md =
423420
memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
424-
Tensor wgh_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(m2);
425-
memory wgh_sc_m =
426-
dpcpp_onednn_memory(wgh_sc_md, engine, wgh_sc.data_ptr());
421+
memory wgh_sc_m = dpcpp_onednn_memory(wgh_sc_md, engine, q_scale_ptr(m2));
427422
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wgh_sc_m});
428423

429424
#ifdef BUILD_PRIOR_SYMM_QUANT

csrc/gpu/oneDNN/QConv.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,7 @@ static at::Tensor quantized_convolution(
485485
.fill_(static_cast<float>(src.q_scale()));
486486
src_sc_m = dpcpp_onednn_memory(src_sc_md, engine, src_sc.data_ptr());
487487
} else {
488-
src_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(src);
489-
src_sc_m = dpcpp_onednn_memory(src_sc_md, engine, src_sc.data_ptr());
488+
src_sc_m = dpcpp_onednn_memory(src_sc_md, engine, q_scale_ptr(src));
490489
}
491490
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_m});
492491

@@ -530,10 +529,9 @@ static at::Tensor quantized_convolution(
530529
#endif
531530

532531
if (wgh.qscheme() == kPerTensorAffine) {
533-
Tensor wgh_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(wgh);
534532
memory::desc wgh_sc_md =
535533
memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
536-
memory wgh_sc_m = dpcpp_onednn_memory(wgh_sc_md, engine, wgh_sc.data_ptr());
534+
memory wgh_sc_m = dpcpp_onednn_memory(wgh_sc_md, engine, q_scale_ptr(wgh));
537535
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wgh_sc_m});
538536

539537
#ifdef BUILD_PRIOR_SYMM_QUANT

csrc/gpu/oneDNN/QDeconv.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,7 @@ static Tensor quantized_deconvolution(
367367
.fill_(static_cast<float>(src.q_scale()));
368368
src_sc_m = dpcpp_onednn_memory(src_sc_md, engine, src_sc.data_ptr());
369369
} else {
370-
src_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(src);
371-
src_sc_m = dpcpp_onednn_memory(src_sc_md, engine, src_sc.data_ptr());
370+
src_sc_m = dpcpp_onednn_memory(src_sc_md, engine, q_scale_ptr(src));
372371
}
373372
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_sc_m});
374373

@@ -387,10 +386,9 @@ static Tensor quantized_deconvolution(
387386
}
388387
#endif
389388

390-
Tensor dst_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(dst);
391389
memory::desc dst_sc_md =
392390
memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
393-
memory dst_sc_m = dpcpp_onednn_memory(dst_sc_md, engine, dst_sc.data_ptr());
391+
memory dst_sc_m = dpcpp_onednn_memory(dst_sc_md, engine, q_scale_ptr(dst));
394392
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
395393

396394
#ifdef BUILD_PRIOR_SYMM_QUANT
@@ -407,10 +405,9 @@ static Tensor quantized_deconvolution(
407405
#endif
408406

409407
if (wgh.qscheme() == kPerTensorAffine) {
410-
Tensor wgh_sc = at::AtenIpexTypeQuantizedXPU::q_scale_tensor(wgh);
411408
memory::desc wgh_sc_md =
412409
memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
413-
memory wgh_sc_m = dpcpp_onednn_memory(wgh_sc_md, engine, wgh_sc.data_ptr());
410+
memory wgh_sc_m = dpcpp_onednn_memory(wgh_sc_md, engine, q_scale_ptr(wgh));
414411
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wgh_sc_m});
415412

416413
#ifdef BUILD_PRIOR_SYMM_QUANT

0 commit comments

Comments (0)