Skip to content

Commit 487a75e

Browse files
committed
mark param tensor when doing model to dpcpp
1 parent 486c416 commit 487a75e

File tree

5 files changed

+134
-1
lines changed

5 files changed

+134
-1
lines changed

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .mlp import *
77
from .jit import *
88
from .save import *
9+
from .to import *
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
import _torch_ipex as core
3+
4+
torch_to = torch.nn.Module.to
5+
6+
def apply(m, fn):
7+
for sub_module in m.children():
8+
apply(sub_module, fn)
9+
fn(m)
10+
return m
11+
12+
def to(module, *args, **kwargs):
13+
m = torch_to(module, *args, **kwargs)
14+
15+
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
16+
17+
if not device or device.type != "dpcpp":
18+
return m
19+
20+
def mark_param(t):
21+
for param in t.parameters():
22+
core.set_parameter_tensor(param.data)
23+
24+
return apply(m, mark_param)
25+
26+
torch.nn.Module.to = to

tests/cpu/test_bf16_lazy_reorder.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,55 @@ def _gen_op(seed, op, is_bn=False, is_forward=True):
7979

8080
return op_cpu, op_auto_mix_inference, op_auto_mix_train, op_man_bf16, op_auto_mix_train_bf16
8181

82+
class CascadedConvBnSumRelu(nn.Module):
83+
def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
84+
super(CascadedConvBnSumRelu, self).__init__()
85+
self.conv = torch.nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
86+
self.conv1 = torch.nn.Conv2d(
87+
mid_channels, out_channels, bias=False, padding=1, **kwargs)
88+
self.conv2 = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
89+
self.bn = torch.nn.BatchNorm2d(mid_channels, eps=0.001)
90+
self.bn1 = torch.nn.BatchNorm2d(out_channels, eps=0.001)
91+
self.bn2 = torch.nn.BatchNorm2d(out_channels, eps=0.001)
92+
93+
def forward(self, x):
94+
a = self.conv(x)
95+
a = self.bn(a)
96+
a = F.relu(a, inplace=True)
97+
a = self.conv1(a)
98+
a = self.bn1(a)
99+
b = self.conv2(x)
100+
b = self.bn2(b)
101+
return F.relu(a.add_(b), inplace=True)
102+
103+
def apply(m, fn, args):
104+
for sub_module in m.children():
105+
apply(sub_module, fn, args)
106+
fn(m, args)
107+
108+
class TestTo(TestCase):
109+
def test_to(self):
110+
rand_seed = int(get_rand_seed())
111+
torch.manual_seed(rand_seed)
112+
113+
m = CascadedConvBnSumRelu(3, 64, 32, kernel_size=3, stride=1)
114+
m_cpu = copy.deepcopy(m).to("cpu")
115+
m_data_type = copy.deepcopy(m).to(torch.bfloat16)
116+
m_auto_mix = copy.deepcopy(m).to(device)
117+
m_auto_mix_data_type = copy.deepcopy(m).to(device=device, dtype=torch.bfloat16)
118+
119+
def check_param(t, is_param):
120+
for param in t.parameters():
121+
if is_param:
122+
self.assertTrue(ipex.core.is_parameter_tensor(param.data))
123+
else:
124+
self.assertFalse(ipex.core.is_parameter_tensor(param.data))
125+
126+
apply(m_cpu, check_param, False)
127+
apply(m_data_type, check_param, False)
128+
apply(m_auto_mix, check_param, True)
129+
apply(m_auto_mix_data_type, check_param, True)
130+
82131
class TestConv(TestCase):
83132
def test_Conv2d_with_cpu(self):
84133
rand_seed = int(get_rand_seed())

torch_ipex/csrc/cpu/ShadeDataContext.h

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ enum SHADE_DATA_TYPE {CPU_RAW, DIL};
1616

1717
enum MIX_PREC_TYPE {NONE, MIX_BF16_FP32, MIX_INT8_FP32};
1818

19+
enum SHADE_TENSOR_TAG{PARAM, OTHER};
20+
1921
#define SANITY_CHECK_SHADE_DATA_CONTEXT(THIS) \
2022
{ \
2123
if (THIS->data_type == SHADE_DATA_TYPE::DIL) { \
@@ -52,12 +54,14 @@ struct ShadeDataContext {
5254

5355
SHADE_DATA_TYPE data_type; ///< Memory buffer type
5456
MIX_PREC_TYPE mix_prec_type; ///< Record if the aten tensor is mix-precision
57+
SHADE_TENSOR_TAG shade_tensor_tag; ///< Record if the tensor is a PARAMETER (in mix-precision, never reorder a PARAMETER to bf16)
5558

5659
ShadeDataContext() : dil_tensor(),
5760
cpu_raw_data(nullptr),
5861
cpu_del_fun(nullptr),
5962
data_type(SHADE_DATA_TYPE::CPU_RAW),
60-
mix_prec_type(MIX_PREC_TYPE::NONE) {}
63+
mix_prec_type(MIX_PREC_TYPE::NONE),
64+
shade_tensor_tag(SHADE_TENSOR_TAG::OTHER) {}
6165

6266
~ShadeDataContext() {
6367
SANITY_CHECK_SHADE_DATA_CONTEXT(this);
@@ -216,6 +220,49 @@ struct ShadeDataContext {
216220
return res;
217221
}
218222

223+
/**
224+
* Check if the input aten tensor is a parameter.
225+
*
226+
* @param tensor input aten tensor
227+
*/
228+
static inline bool isParameterTensor(const at::Tensor &tensor) {
229+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.has_storage());
230+
231+
if (tensor.device().type() != c10::DeviceType::DPCPP) {
232+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.device().type() == c10::DeviceType::CPU);
233+
return false;
234+
}
235+
236+
void *storage_context = tensor.storage().data_ptr().get_context();
237+
ShadeDataContext *shade_data_context = (ShadeDataContext*)storage_context;
238+
auto shade_tensor_tag = shade_data_context->shade_tensor_tag;
239+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((shade_tensor_tag == SHADE_TENSOR_TAG::OTHER) || (shade_tensor_tag == SHADE_TENSOR_TAG::PARAM));
240+
241+
SANITY_CHECK_SHADE_DATA_CONTEXT(shade_data_context);
242+
243+
return shade_tensor_tag == SHADE_TENSOR_TAG::PARAM;
244+
}
245+
246+
/**
247+
* Set the shade_tensor_tag of the input aten tensor to PARAM.
248+
*
249+
* @param tensor input aten tensor
250+
*/
251+
static inline void setParameterTensor(const at::Tensor &tensor) {
252+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.has_storage());
253+
254+
// TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.device().type() == c10::DeviceType::DPCPP);
255+
// TODO: if device is cpu, this function should not be called
256+
if (tensor.device().type() != c10::DeviceType::DPCPP) {
257+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.device().type() == c10::DeviceType::CPU);
258+
return;
259+
}
260+
261+
void *storage_context = tensor.storage().data_ptr().get_context();
262+
ShadeDataContext *shade_data_context = (ShadeDataContext*)storage_context;
263+
shade_data_context->shade_tensor_tag = SHADE_TENSOR_TAG::PARAM;
264+
}
265+
219266
};
220267

221268
} // namespace cpu

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ void setAutoDNNL(bool val) {
4040
AutoOptConfig::singleton().set_auto_dnnl(val);
4141
}
4242

43+
void setParameterTensor(const at::Tensor &tensor) {
44+
cpu::ShadeDataContext::setParameterTensor(tensor);
45+
}
46+
47+
bool isParameterTensor(const at::Tensor &tensor) {
48+
return cpu::ShadeDataContext::isParameterTensor(tensor);
49+
}
50+
4351
/// **** Only for unit test ****
4452
bool isDilTensor(const at::Tensor &tensor) {
4553
return cpu::ShadeDataContext::isDilTensor(tensor);
@@ -125,6 +133,8 @@ void InitIpexModuleBindings(py::module m) {
125133
m.def("is_fp32_dil_tensor", &isFP32DilTensor);
126134
m.def("get_dil_tensor_sizes", &getDilStorageSizes);
127135
m.def("get_dil_tensor_strides", &getDilStorageStrides);
136+
m.def("set_parameter_tensor", &setParameterTensor);
137+
m.def("is_parameter_tensor", &isParameterTensor);
128138
m.def("enable_jit_opt", []() { AutoOptConfig::singleton().set_jit_fuse(true); });
129139
m.def("disable_jit_opt", []() { AutoOptConfig::singleton().set_jit_fuse(false); });
130140
m.def("get_jit_opt", []() { return AutoOptConfig::singleton().get_jit_fuse(); });

0 commit comments

Comments
 (0)