Skip to content

Commit 2324119

Browse files
authored
enable the fallback to cpu for linear (#96)
1 parent f38efff commit 2324119

File tree

1 file changed

+50
-9
lines changed

1 file changed

+50
-9
lines changed

torch_ipex/csrc/cpu/CustomOPs.h

Lines changed: 50 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,25 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
1818
at::Tensor weight,
1919
at::Tensor bias = at::Tensor()) {
2020
ctx->save_for_backward({input, weight, bias});
21-
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
22-
return torch_ipex::cpu::AtenIpexCPUDev::dil_linear(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.is_contiguous() ? bias : bias.contiguous());
21+
try {
22+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
23+
return torch_ipex::cpu::AtenIpexCPUDev::dil_linear(input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.is_contiguous() ? bias : bias.contiguous());
24+
}
25+
} catch (std::exception& e) {
26+
#if defined(_DEBUG)
27+
TORCH_WARN(e.what());
28+
#endif
29+
}
30+
if (input.device().type() == c10::DeviceType::DPCPP) {
31+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == c10::kStrided);
32+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
33+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(bias.layout() == c10::kStrided);
34+
auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
35+
auto&& _ipex_weight = torch_ipex::bridge::shallowFallbackToCPUTensor(weight);
36+
auto&& _ipex_bias = torch_ipex::bridge::shallowFallbackToCPUTensor(bias);
37+
auto&& _ipex_result = at::linear(_ipex_input, _ipex_weight, _ipex_bias);
38+
static_cast<void>(_ipex_result); // Avoid warnings in case not used
39+
return torch_ipex::bridge::shallowUpgradeToDPCPPTensor(_ipex_result);
2340
} else {
2441
return at::linear(input, weight, bias);
2542
}
@@ -36,20 +53,44 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
3653
at::Tensor grad_output = grad_outputs[0];
3754
at::Tensor grad_input, grad_weight;
3855
at::Tensor grad_bias = torch::Tensor();
39-
40-
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
41-
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input(
42-
input.sizes(), grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), weight.is_contiguous() ? weight : weight.contiguous());
43-
std::tie(grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights(
44-
grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.defined());
56+
57+
try {
58+
if (torch_ipex::check_auto_dnnl() && input.device().type() == c10::DeviceType::DPCPP) {
59+
grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input(
60+
input.sizes(), grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), weight.is_contiguous() ? weight : weight.contiguous());
61+
std::tie(grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights(
62+
grad_output.is_contiguous() ? grad_output : grad_output.contiguous(), input.is_contiguous() ? input : input.contiguous(), weight.is_contiguous() ? weight : weight.contiguous(), bias.defined());
63+
return {grad_input, grad_weight, grad_bias};
64+
}
65+
} catch (std::exception& e) {
66+
#if defined(_DEBUG)
67+
TORCH_WARN(e.what());
68+
#endif
69+
}
70+
if (input.device().type() == c10::DeviceType::DPCPP) {
71+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == c10::kStrided);
72+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(weight.layout() == c10::kStrided);
73+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grad_output.layout() == c10::kStrided);
74+
auto&& _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor(input);
75+
auto&& _ipex_weight = torch_ipex::bridge::shallowFallbackToCPUTensor(weight);
76+
auto&& _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor(grad_output);
77+
grad_input = _ipex_grad_output.mm(_ipex_weight);
78+
grad_weight = _ipex_grad_output.t().mm(_ipex_input);
79+
if (bias.defined()) {
80+
grad_bias = _ipex_grad_output.sum(0);
81+
}
82+
static_cast<void>(grad_input);
83+
static_cast<void>(grad_weight);
84+
static_cast<void>(grad_bias);
85+
return {torch_ipex::bridge::shallowUpgradeToDPCPPTensor(grad_input), torch_ipex::bridge::shallowUpgradeToDPCPPTensor(grad_weight), torch_ipex::bridge::shallowUpgradeToDPCPPTensor(grad_bias)};
4586
} else {
4687
grad_input = grad_output.mm(weight);
4788
grad_weight = grad_output.t().mm(input);
4889
if (bias.defined()) {
4990
grad_bias = grad_output.sum(0);
5091
}
92+
return {grad_input, grad_weight, grad_bias};
5193
}
52-
return {grad_input, grad_weight, grad_bias};
5394
}
5495
};
5596

0 commit comments

Comments
 (0)