Skip to content

Commit 49c5190

Browse files
authored
Inductor: register meta function for linear_pointwise (#4791) (#4966)
* Inductor: register meta function for linear_pointwise * refine format
1 parent 9d489a8 commit 49c5190

File tree

3 files changed

+90
-24
lines changed

3 files changed

+90
-24
lines changed

csrc/gpu/aten/operators/Linear.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -416,26 +416,6 @@ Tensor dpcpp_linear(
416416
return linear_wrapper.call(result, input, weight, bias, post_op);
417417
}
418418

419-
// Shape-only variant of linear_pointwise: builds the post-op attribute and
// delegates to matmul_fusion_variants_meta to produce the result tensor.
Tensor linear_pointwise_meta(
    const Tensor& input_t,
    const Tensor& weight_t,
    const c10::optional<Tensor>& bias_opt,
    c10::string_view attr,
    torch::List<c10::optional<at::Scalar>> scalars,
    c10::optional<c10::string_view> algorithm) {
  // Translate the unary post-op description into an Attr.
  Attr post_op;
  post_op = construct_unary_attr(attr, scalars, algorithm, post_op);

  // A missing bias is represented by a default-constructed tensor.
  Tensor bias;
  if (bias_opt.has_value()) {
    bias = bias_opt.value();
  }

  // Inputs with more than two dimensions are routed through
  // contiguous_if_needed; 1-D and 2-D inputs are used as they are.
  Tensor input = input_t;
  if (input_t.dim() > 2) {
    input = torch_ipex::xpu::oneDNN::contiguous_if_needed(input_t);
  }

  Tensor output;
  bool is_fused = false;
  output = matmul_fusion_variants_meta(
      output, input, weight_t, true, post_op, /*is_fused*/ is_fused, bias);
  return output;
}
438-
439419
Tensor linear_pointwise(
440420
const Tensor& input_t,
441421
const Tensor& weight_t,
@@ -479,10 +459,6 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
479459
}
480460

481461
TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
482-
m.impl(
483-
TORCH_SELECTIVE_NAME("torch_ipex::_linear_pointwise"),
484-
c10::DispatchKey::Meta,
485-
TORCH_FN(linear_pointwise_meta));
486462
m.impl(
487463
TORCH_SELECTIVE_NAME("torch_ipex::_linear_pointwise"),
488464
c10::DispatchKey::XPU,

intel_extension_for_pytorch/_inductor/xpu/_meta_registrations.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,10 @@ def meta_torch_ipex_convolution_binary_inplace(
8080
unary_algorithm,
8181
):
8282
return other
83+
84+
85+
@register_meta("_linear_pointwise", "default")
def meta_torch_ipex_linear_pointwise_default(
    input_tensor, weight, bias, attr, scalars, algorithm
):
    """Meta implementation of ``_linear_pointwise``.

    Returns an uninitialized tensor created from ``input_tensor`` via
    ``new_empty``. Its shape keeps every leading dimension of
    ``input_tensor`` and replaces the last one with ``weight.shape[0]``.

    ``bias``, ``attr``, ``scalars`` and ``algorithm`` are accepted to match
    the operator signature but are not read here.
    """
    leading_dims = input_tensor.shape[:-1]
    out_features = weight.shape[0]
    return input_tensor.new_empty((*leading_dims, out_features))
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch._dynamo
4+
from torch._inductor import config
5+
6+
from torch.testing._internal.common_utils import TestCase
7+
8+
import intel_extension_for_pytorch # noqa
9+
import numpy as np
10+
import pytest
11+
import platform
12+
13+
# Raise NumPy's print threshold to infinity for this module.
np.set_printoptions(threshold=np.inf)

# Device handles shared by the tests in this module.
cpu_device = torch.device("cpu")
dpcpp_device = torch.device("xpu")
17+
18+
19+
# Unary post-op modules mapped to match_nodes (computation_op + unary_op).
# See pytorch/test/inductor/test_mkldnn_pattern_matcher.py
unary_list = {
    nn.ReLU(): 2,
    nn.Sigmoid(): 2,
    nn.Tanh(): 2,
    nn.Hardswish(): 6,
    nn.LeakyReLU(0.1): 4,
    nn.Hardtanh(-0.5, 4): 3,
    nn.Hardtanh(-0.5, float("inf")): 3,
    nn.GELU(approximate="none"): 6,
    nn.GELU(approximate="tanh"): 10,
    nn.ReLU6(): 3,
    nn.SiLU(): 3,
    nn.Hardsigmoid(): 5,
}
35+
36+
# Binary post-op callables mapped to (match_count, match_nodes, inplace).
# See pytorch/test/inductor/test_mkldnn_pattern_matcher.py
# NOTE(review): `add_` is tagged inplace=False while `sub_` is tagged
# inplace=True -- confirm whether that asymmetry is intended.
binary_list = {
    # torch.add, called as a function with either operand order
    lambda lhs, rhs: torch.add(lhs, rhs): (1, 2, False),
    lambda lhs, rhs: torch.add(rhs, lhs): (1, 2, False),
    # add, called as a tensor method
    lambda lhs, rhs: lhs.add(rhs): (1, 2, False),
    lambda lhs, rhs: lhs.add_(rhs): (1, 2, False),
    # torch.sub, called as a function
    lambda lhs, rhs: torch.sub(lhs, rhs): (1, 2, False),
    # sub, called as a tensor method
    lambda lhs, rhs: lhs.sub(rhs): (1, 2, False),
    lambda lhs, rhs: lhs.sub_(rhs): (1, 2, True),
}
47+
48+
49+
class N(nn.Module):
    """A ``torch.nn.Linear`` layer followed by a unary function.

    Args:
        in_channels: Input feature size passed to ``torch.nn.Linear``.
        out_channels: Output feature size passed to ``torch.nn.Linear``.
        unary_fn: Callable applied to the linear output in ``forward``.
        **kwargs: Extra keyword arguments forwarded to ``torch.nn.Linear``
            (for example ``bias=False``).
    """

    def __init__(self, in_channels, out_channels, unary_fn, **kwargs):
        super().__init__()
        self.linear = torch.nn.Linear(in_channels, out_channels, **kwargs)
        self.unary_fn = unary_fn

    def forward(self, x):
        """Return ``unary_fn(linear(x))``."""
        return self.unary_fn(self.linear(x))
59+
60+
61+
class TestTorchMethod(TestCase):
    """Inductor tests for a Linear layer followed by a unary function."""

    @pytest.mark.skipif(
        platform.system() == "Windows" or "WSL2" in platform.uname().release,
        reason="Windows not yet supported for torch.compile",
    )
    @config.patch({"freezing": True})
    def test_inductor_fusion_linear(self):
        """Compare compiled and eager outputs of ``N`` on the XPU device.

        For every module in ``unary_list``, with ``dynamic`` both enabled and
        disabled, a bias-free ``N(3, 4, unary_fn)`` is compiled with the
        "inductor" backend and its output on a seeded random input is
        asserted equal to the output of the uncompiled model.
        """
        device = dpcpp_device
        for unary_fn in unary_list:
            for dynamic in (True, False):
                model = N(3, 4, unary_fn, bias=False).to(device)
                model.eval()

                with torch.no_grad():
                    # Debugging aid, left disabled:
                    # with torch.xpu.onednn_verbose(2):
                    run = torch.compile(
                        model, backend="inductor", dynamic=dynamic
                    )
                    # Seed right before sampling so every iteration sees the
                    # same input values.
                    torch.manual_seed(0)
                    example_input = torch.randn(3, 3).to(device)
                    actual = run(example_input)

                    ref = model(example_input)
                    self.assertEqual(ref, actual)

0 commit comments

Comments
 (0)