Skip to content

Commit c6ea20b

Browse files
committed
Revert "Enable quant save/load through prepack fn registration (#3078)"
This reverts commit 7db2f1c.
1 parent 5af6dbb commit c6ea20b

File tree

11 files changed

+235
-87
lines changed

11 files changed

+235
-87
lines changed

csrc/gpu/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ if (USE_PROFILER)
150150
list(APPEND IPEX_COMPILE_DEFINITIONS "USE_PROFILER")
151151
endif()
152152

153+
if (BUILD_JIT_QUANTIZATION_SAVE)
154+
list(APPEND IPEX_COMPILE_DEFINITIONS "BUILD_JIT_QUANTIZATION_SAVE")
155+
endif()
156+
153157
if (USE_SPLIT_FP64_LOOPS)
154158
list(APPEND IPEX_COMPILE_DEFINITIONS "USE_SPLIT_FP64_LOOPS")
155159
endif()

csrc/gpu/aten/operators/QConv_prepack.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#include <oneDNN/oneDNN.h>
33
#include <runtime/Utils.h>
44

5-
#include <ATen/native/quantized/PackedParams.h>
65
#include "comm/ParamUtils.h"
76

87
#include <quantized/QUtils.h>
@@ -124,18 +123,3 @@ TORCH_LIBRARY_IMPL(quantized, XPU, m) {
124123

125124
} // namespace AtenIpexTypeQuantizedXPU
126125
} // namespace at
127-
128-
int init_prepack_fn() {
129-
register_prepack<2>(
130-
at::QEngine::QXPU,
131-
at::AtenIpexTypeQuantizedXPU::PackedConvWeightQDPCPP<2>::prepack);
132-
register_prepack<3>(
133-
at::QEngine::QXPU,
134-
at::AtenIpexTypeQuantizedXPU::PackedConvWeightQDPCPP<3>::prepack);
135-
register_linear_prepack(
136-
at::QEngine::QXPU,
137-
at::AtenIpexTypeQuantizedXPU::PackedLinearWeightQDPCPP::prepack);
138-
return 1;
139-
}
140-
141-
auto xpu_prepack = init_prepack_fn();

csrc/gpu/aten/quantized/QTensor.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,7 @@ Tensor& set_(
131131
auto* self_ = self.unsafeGetTensorImpl();
132132
self_->set_storage_keep_dtype(storage);
133133
self_->set_storage_offset(storage_offset);
134-
if (strides.data() == nullptr) {
135-
self_->set_sizes_contiguous(sizes);
136-
} else {
137-
self_->set_sizes_and_strides(sizes, strides);
138-
}
134+
self_->set_sizes_and_strides(sizes, strides);
139135
return self;
140136
}
141137

csrc/gpu/aten/quantized/QUtils.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/native/TensorFactories.h>
3+
#include <ATen/quantized/QTensorImpl.h>
4+
#include <ATen/quantized/Quantizer.h>
5+
#include <c10/core/QScheme.h>
6+
#include <c10/core/TensorOptions.h>
7+
#include <c10/util/accumulate.h>
8+
#include <torch/custom_class.h>
9+
#include <torch/custom_class_detail.h>
10+
11+
#include <oneapi/dnnl/dnnl.hpp>
12+
#include <quantized/QUtils.h>
13+
14+
#ifdef BUILD_JIT_QUANTIZATION_SAVE
15+
// Following code is not in any namespace. This is due to
16+
// we align to PyTorch side. If any code is need added in this
17+
// file except packedparam serialization, please write it in a
18+
// proper namespace.
19+
// QConv prepack pickling method hacking
20+
template <int kSpatialDim = 2>
21+
torch::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params();
22+
23+
extern template torch::class_<ConvPackedParamsBase<2>> register_conv_params<
24+
2>();
25+
extern template torch::class_<ConvPackedParamsBase<3>> register_conv_params<
26+
3>();
27+
28+
template <int kSpatialDim = 2>
29+
ConvParamsSerializationTypeV2 serialize_conv(
30+
const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params);
31+
extern template ConvParamsSerializationTypeV2 serialize_conv(
32+
const c10::intrusive_ptr<ConvPackedParamsBase<2>>& params);
33+
extern template ConvParamsSerializationTypeV2 serialize_conv(
34+
const c10::intrusive_ptr<ConvPackedParamsBase<3>>& params);
35+
36+
template <uint32_t kSpatialDim>
37+
ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v);
38+
39+
template <int kSpatialDim>
40+
int redefine_prepack() {
41+
auto conv_prepack_class = register_conv_params<kSpatialDim>();
42+
auto clsptr = torch::getCustomClass(
43+
"__torch__.torch.classes.quantized.Conv" + c10::to_string(kSpatialDim) +
44+
"dPackedParamsBase");
45+
clsptr->unsafeRemoveMethod("__getstate__");
46+
clsptr->unsafeRemoveMethod("__setstate__");
47+
conv_prepack_class.def_pickle(
48+
[](const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params)
49+
-> ConvParamsSerializationType { // __getstate__
50+
return serialize_conv<kSpatialDim>(params);
51+
},
52+
// __setstate__ takes c10::IValue because we support parsing historical
53+
// serialization versions.
54+
[](c10::IValue v) -> c10::intrusive_ptr<
55+
ConvPackedParamsBase<kSpatialDim>> { // __setstate__
56+
ConvParamsSerializationTypeV3 state =
57+
parse_conv_serialized_state<kSpatialDim>(v);
58+
return deserialize_conv_dpcpp<kSpatialDim>(state);
59+
});
60+
return 0;
61+
}
62+
63+
template int redefine_prepack<2>();
64+
template int redefine_prepack<3>();
65+
66+
// QLinear prepack pickling method hacking
67+
torch::class_<LinearPackedParamsBase> register_linear_params();
68+
69+
int redefine_linear_prepack() {
70+
auto linear_prepack_class = register_linear_params();
71+
auto clsptr = torch::getCustomClass(
72+
"__torch__.torch.classes.quantized.LinearPackedParamsBase");
73+
clsptr->unsafeRemoveMethod("__getstate__");
74+
clsptr->unsafeRemoveMethod("__setstate__");
75+
using SerializationType = std::tuple<at::Tensor, c10::optional<at::Tensor>>;
76+
linear_prepack_class.def_pickle(
77+
[](const c10::intrusive_ptr<LinearPackedParamsBase>& params)
78+
-> SerializationType { // __getstate__
79+
at::Tensor weight;
80+
c10::optional<at::Tensor> bias;
81+
std::tie(weight, bias) = params->unpack();
82+
return std::make_tuple(std::move(weight), std::move(bias));
83+
},
84+
[](SerializationType state)
85+
-> c10::intrusive_ptr<LinearPackedParamsBase> { // __setstate__
86+
at::Tensor weight;
87+
c10::optional<at::Tensor> bias;
88+
weight = std::move(std::get<0>(state));
89+
bias = std::move(std::get<1>(state));
90+
91+
return at::AtenIpexTypeQuantizedXPU::PackedLinearWeightQDPCPP::prepack(
92+
std::move(weight), std::move(bias));
93+
});
94+
return 0;
95+
}
96+
97+
namespace {
98+
static auto conv2d_params = redefine_prepack<2>();
99+
static auto conv3d_params = redefine_prepack<3>();
100+
static auto linear_params = redefine_linear_prepack();
101+
} // namespace
102+
#endif

csrc/gpu/aten/quantized/QUtils.h

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
namespace xpu {
2222
namespace dpcpp {
23-
2423
// Note: [Opaque u8 tensor]
2524
// Due to the difference between oneDNN and PyTorch u8 quantization, we quant
2625
// tensor with kQUint8 and 128 zp to memory::data_type::s8 and 0 zp inside. This
@@ -327,3 +326,93 @@ struct PackedLinearWeightQDPCPP : public LinearPackedParamsBase {
327326

328327
} // namespace AtenIpexTypeQuantizedXPU
329328
} // namespace at
329+
330+
#ifdef BUILD_JIT_QUANTIZATION_SAVE
331+
332+
// Repeat torch type definition here again
333+
using ConvParamsSerializationTypeV2 = std::tuple<
334+
// version, for versions 2 and up
335+
std::string,
336+
// non-optional tensors
337+
std::vector<at::Tensor>,
338+
// optional tensors
339+
std::vector<c10::optional<at::Tensor>>>;
340+
using ConvParamsSerializationTypeV3 = std::tuple<
341+
// version, int for versions 3 and up
342+
int64_t,
343+
// configuration values
344+
std::vector<int64_t>,
345+
// optional tensors
346+
std::vector<c10::optional<at::Tensor>>>;
347+
348+
using ConvParamsSerializationType = ConvParamsSerializationTypeV2;
349+
350+
template <uint32_t kSpatialDim>
351+
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv_dpcpp(
352+
ConvParamsSerializationTypeV3 state) {
353+
int64_t version;
354+
std::vector<int64_t> config_vals;
355+
std::vector<c10::optional<at::Tensor>> tensors;
356+
357+
std::tie(version, config_vals, tensors) = state;
358+
TORCH_INTERNAL_ASSERT(
359+
version == 3, "Unexpected serialized qconv version: ", version);
360+
361+
TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size());
362+
c10::optional<at::Tensor> weight = tensors[1];
363+
c10::optional<at::Tensor> bias = tensors[2];
364+
TORCH_INTERNAL_ASSERT(
365+
weight, "Weight should always be present in serialized qconv.");
366+
367+
torch::List<int64_t> stride, padding, output_padding, dilation;
368+
// skip kSpatialDim
369+
int idx = 1;
370+
for (const auto i : c10::irange(kSpatialDim)) {
371+
(void)i; // Suppress unused variable
372+
stride.emplace_back(config_vals.at(idx));
373+
idx++;
374+
}
375+
for (const auto i : c10::irange(kSpatialDim)) {
376+
(void)i; // Suppress unused variable
377+
padding.emplace_back(config_vals.at(idx));
378+
idx++;
379+
}
380+
for (const auto i : c10::irange(kSpatialDim)) {
381+
(void)i; // Suppress unused variable
382+
dilation.emplace_back(config_vals.at(idx));
383+
idx++;
384+
}
385+
for (const auto i : c10::irange(kSpatialDim)) {
386+
(void)i; // Suppress unused variable
387+
output_padding.emplace_back(config_vals.at(idx));
388+
idx++;
389+
}
390+
int64_t groups = config_vals.at(idx);
391+
idx++;
392+
int64_t flags = config_vals.at(idx);
393+
idx++;
394+
TORCH_INTERNAL_ASSERT(
395+
idx == static_cast<int64_t>(config_vals.size()),
396+
"Unexpected length of config_vals, expected ",
397+
idx,
398+
" got ",
399+
config_vals.size());
400+
401+
bool transpose = flags & (1 << 0);
402+
403+
int64_t other_flags = flags & ~(1 << 0);
404+
TORCH_INTERNAL_ASSERT(
405+
other_flags == 0, "Unexpected flags set in ", flags, ".");
406+
407+
return at::AtenIpexTypeQuantizedXPU::PackedConvWeightQDPCPP<kSpatialDim>::
408+
prepack(
409+
weight.value(),
410+
bias,
411+
stride,
412+
padding,
413+
output_padding,
414+
dilation,
415+
groups,
416+
transpose);
417+
}
418+
#endif

csrc/gpu/utils/Settings.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include <ATen/native/quantized/PackedParams.h>
21
#include <oneDNN/Runtime.h>
32
#include <runtime/Device.h>
43
#include <utils/Settings.h>
@@ -293,6 +292,14 @@ bool Settings::is_channels_last_1d_enabled() const {
293292
#endif
294293
}
295294

295+
// Report whether this build was compiled with BUILD_JIT_QUANTIZATION_SAVE,
// i.e. whether quantized JIT model save/load is available on XPU.
bool Settings::is_jit_quantization_save_enabled() const {
#ifdef BUILD_JIT_QUANTIZATION_SAVE
  return true;
#else
  return false;
#endif
}
302+
296303
bool Settings::is_xetla_enabled() const {
297304
#if defined(USE_XETLA)
298305
return true;

csrc/gpu/utils/Settings.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class IPEX_API Settings final {
7272
bool is_multi_context_enabled() const;
7373

7474
bool is_channels_last_1d_enabled() const;
75+
bool is_jit_quantization_save_enabled() const;
7576
bool is_xetla_enabled() const;
7677

7778
bool is_simple_trace_enabled() const;

intel_extension_for_pytorch/csrc/xpu/Module.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,10 @@ void init_xpu_module(pybind11::module& m) {
677677
return Settings::I().is_multi_context_enabled();
678678
});
679679

680+
m.def("_is_jit_quantization_save_enabled", []() {
681+
return Settings::I().is_jit_quantization_save_enabled();
682+
});
683+
680684
m.def("_is_channels_last_1d_enabled", []() {
681685
return Settings::I().is_channels_last_1d_enabled();
682686
});

intel_extension_for_pytorch/xpu/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@ def disable_tile_as_device():
310310
################################################################
311311

312312

313+
def has_jit_quantization_save():
    """Tell whether JIT quantization model save/load is supported.

    Returns whatever the native binding reports — True only when the
    extension was compiled with the BUILD_JIT_QUANTIZATION_SAVE definition.
    """
    return _C._is_jit_quantization_save_enabled()
315+
316+
313317
def has_xetla():
314318
return _C._is_xetla_enabled()
315319

0 commit comments

Comments
 (0)