Skip to content

Commit b7773b4

Browse files
author
ssjia
committed
Update on "[ET-VK] High dim tensor support for view, unsqueeze, squeeze, clone"
Differential Revision: [D80800084](https://our.internmc.facebook.com/intern/diff/D80800084) [ghstack-poisoned]
2 parents ab88e9b + b964d48 commit b7773b4

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def register_view_ops():
502502

503503
@update_features(
504504
[
505-
exir_ops.edge.aten.view.default,
505+
exir_ops.edge.aten.view_copy.default,
506506
exir_ops.edge.aten.squeeze_copy.dims,
507507
exir_ops.edge.aten.unsqueeze_copy.default,
508508
exir_ops.edge.aten.clone.default,

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,20 +1777,6 @@ def forward(self, x):
17771777
(torch.rand(size=[1, 5, 2, 3]),),
17781778
)
17791779

1780-
def test_vulkan_backend_high_dim_tensors_fail(self):
1781-
class UnsqueezeHigherDim(torch.nn.Module):
1782-
def __init__(self):
1783-
super().__init__()
1784-
1785-
def forward(self, x):
1786-
return torch.unsqueeze(x, 2)
1787-
1788-
self.lower_module_and_test_output(
1789-
UnsqueezeHigherDim(),
1790-
(torch.ones(size=[5, 4, 1, 2, 6]),),
1791-
expect_no_delegates=True,
1792-
)
1793-
17941780
def test_vulkan_backend_large_linear_layer(self):
17951781
class LinearModel(torch.nn.Module):
17961782
def __init__(self, large_out_channels: int) -> None:
@@ -2298,6 +2284,28 @@ def forward(self, x1, x2, x3, x4, x5, x6):
22982284
test_inputs=test_inputs,
22992285
)
23002286

2287+
def test_vulkan_backend_high_dimensional_tensors(self):
2288+
class HighDimTensorModule(torch.nn.Module):
2289+
def __init__(self):
2290+
super().__init__()
2291+
2292+
def forward(self, x, y):
2293+
# Unsqueeze inputs twice to create 5-dim tensors
2294+
x_5d = torch.unsqueeze(torch.unsqueeze(x, 0), 0)
2295+
y_5d = torch.unsqueeze(torch.unsqueeze(y, 0), 0)
2296+
# Add tensors together
2297+
result = x_5d + y_5d
2298+
return result
2299+
2300+
high_dim_module = HighDimTensorModule()
2301+
# Create 2 4-dim inputs
2302+
sample_inputs = (
2303+
torch.rand(size=(2, 3, 4, 5), dtype=torch.float32),
2304+
torch.rand(size=(2, 3, 4, 5), dtype=torch.float32),
2305+
)
2306+
2307+
self.lower_module_and_test_output(high_dim_module, sample_inputs)
2308+
23012309
def test_vulkan_backend_torchao_wo_quantized_linear(self):
23022310
in_features = 1024
23032311
out_features = 512

0 commit comments

Comments
 (0)