@@ -18,8 +18,25 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
1818 at::Tensor weight,
1919 at::Tensor bias = at::Tensor()) {
2020 ctx->save_for_backward ({input, weight, bias});
21- if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
22- return torch_ipex::cpu::AtenIpexCPUDev::dil_linear (input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.is_contiguous () ? bias : bias.contiguous ());
21+ try {
22+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
23+ return torch_ipex::cpu::AtenIpexCPUDev::dil_linear (input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.is_contiguous () ? bias : bias.contiguous ());
24+ }
25+ } catch (std::exception& e) {
26+ #if defined(_DEBUG)
27+ TORCH_WARN (e.what ());
28+ #endif
29+ }
30+ if (input.device ().type () == c10::DeviceType::DPCPP) {
31+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (input.layout () == c10::kStrided );
32+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (weight.layout () == c10::kStrided );
33+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (bias.layout () == c10::kStrided );
34+ auto && _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor (input);
35+ auto && _ipex_weight = torch_ipex::bridge::shallowFallbackToCPUTensor (weight);
36+ auto && _ipex_bias = torch_ipex::bridge::shallowFallbackToCPUTensor (bias);
37+ auto && _ipex_result = at::linear (_ipex_input, _ipex_weight, _ipex_bias);
38+ static_cast <void >(_ipex_result); // Avoid warnings in case not used
39+ return torch_ipex::bridge::shallowUpgradeToDPCPPTensor (_ipex_result);
2340 } else {
2441 return at::linear (input, weight, bias);
2542 }
@@ -36,20 +53,44 @@ class NewLinearOp : public torch::autograd::Function<NewLinearOp> {
3653 at::Tensor grad_output = grad_outputs[0 ];
3754 at::Tensor grad_input, grad_weight;
3855 at::Tensor grad_bias = torch::Tensor ();
39-
40- if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
41- grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input (
42- input.sizes (), grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), weight.is_contiguous () ? weight : weight.contiguous ());
43- std::tie (grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights (
44- grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.defined ());
56+
57+ try {
58+ if (torch_ipex::check_auto_dnnl () && input.device ().type () == c10::DeviceType::DPCPP) {
59+ grad_input = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_input (
60+ input.sizes (), grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), weight.is_contiguous () ? weight : weight.contiguous ());
61+ std::tie (grad_weight, grad_bias) = torch_ipex::cpu::AtenIpexCPUDev::dil_linear_backward_weights (
62+ grad_output.is_contiguous () ? grad_output : grad_output.contiguous (), input.is_contiguous () ? input : input.contiguous (), weight.is_contiguous () ? weight : weight.contiguous (), bias.defined ());
63+ return {grad_input, grad_weight, grad_bias};
64+ }
65+ } catch (std::exception& e) {
66+ #if defined(_DEBUG)
67+ TORCH_WARN (e.what ());
68+ #endif
69+ }
70+ if (input.device ().type () == c10::DeviceType::DPCPP) {
71+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (input.layout () == c10::kStrided );
72+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (weight.layout () == c10::kStrided );
73+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (grad_output.layout () == c10::kStrided );
74+ auto && _ipex_input = torch_ipex::bridge::shallowFallbackToCPUTensor (input);
75+ auto && _ipex_weight = torch_ipex::bridge::shallowFallbackToCPUTensor (weight);
76+ auto && _ipex_grad_output = torch_ipex::bridge::shallowFallbackToCPUTensor (grad_output);
77+ grad_input = _ipex_grad_output.mm (_ipex_weight);
78+ grad_weight = _ipex_grad_output.t ().mm (_ipex_input);
79+ if (bias.defined ()) {
80+ grad_bias = _ipex_grad_output.sum (0 );
81+ }
82+ static_cast <void >(grad_input);
83+ static_cast <void >(grad_weight);
84+ static_cast <void >(grad_bias);
85+ return {torch_ipex::bridge::shallowUpgradeToDPCPPTensor (grad_input), torch_ipex::bridge::shallowUpgradeToDPCPPTensor (grad_weight), torch_ipex::bridge::shallowUpgradeToDPCPPTensor (grad_bias)};
4586 } else {
4687 grad_input = grad_output.mm (weight);
4788 grad_weight = grad_output.t ().mm (input);
4889 if (bias.defined ()) {
4990 grad_bias = grad_output.sum (0 );
5091 }
92+ return {grad_input, grad_weight, grad_bias};
5193 }
52- return {grad_input, grad_weight, grad_bias};
5394 }
5495};
5596
0 commit comments