
Commit 8c09745

[ET-VK] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator
Differential Revision: D77038781
Pull Request resolved: #11828
1 parent bb805ad commit 8c09745

File tree

4 files changed, +130 -12 lines changed


backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 25 additions & 9 deletions
@@ -91,7 +91,7 @@ def __init__(
         self.default_layout: VkMemoryLayout = default_memory_layout
         self.texture_limits = texture_limits
 
-    def propose_node_storage(
+    def propose_node_storage(  # noqa: C901
         self,
         node: torch.fx.Node,
     ) -> Optional[VkStorageType]:
@@ -138,15 +138,23 @@ def propose_node_storage(
         for arg in node.args:
             if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 storage = utils.get_node_storage_type(arg)
+                # Some operators which return multiple output tensors may specify a
+                # different storage type for each output. In this case, the storage type
+                # for the first output is used.
+                if isinstance(storage, (list, tuple)):
+                    storage = storage[0]
                 if storage is not None and storage in valid_storage_types:
                     return storage
 
         # If no storage type has been resolved yet, assume the optimal storage type of
         # the first opinionated user. This search is recursive.
         for user in node.users:
-            optimal_storage = self.propose_node_storage(user)
-            if optimal_storage is not None:
-                return optimal_storage
+            storage = self.propose_node_storage(user)
+            # See above
+            if isinstance(storage, (list, tuple)):
+                storage = storage[0]
+            if storage is not None:
+                return storage
 
         if self.default_storage in valid_storage_types:
             return self.default_storage
@@ -179,15 +187,23 @@ def propose_node_layout(
         for arg in node.args:
             if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 layout = utils.get_node_memory_layout(arg)
+                # Some operators which return multiple output tensors may specify a
+                # different memory layout for each output. In this case, the memory
+                # layout for the first output is used.
+                if isinstance(layout, (list, tuple)):
+                    layout = layout[0]
                 if layout is not None and layout in valid_layouts:
                     return layout
 
-        # If no storage type has been resolved yet, assume the optimal storage type of
-        # the first opinionated user. This search is recursive.
+        # If no memory layout has been resolved yet, assume the optimal layout of the
+        # first opinionated user. This search is recursive.
         for user in node.users:
-            optimal_storage = self.propose_node_layout(user, storage)
-            if optimal_storage is not None:
-                return optimal_storage
+            layout = self.propose_node_layout(user, storage)
+            # See above comment
+            if isinstance(layout, (list, tuple)):
+                layout = layout[0]
+            if layout is not None:
+                return layout
 
         # As a last resort, return the default storage type that should be used.
         if self.default_layout in valid_layouts:
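
The fallback added in both functions is the same: when an upstream value turns out to be a per-output list, take the entry for the first output. A minimal standalone sketch of that selection rule (not part of the commit; VkStorageType here is a stand-in enum for the backend's real type):

from enum import Enum
from typing import Optional, Sequence, Union


class VkStorageType(Enum):  # stand-in for the real enum in the Vulkan backend
    BUFFER = 0
    TEXTURE_3D = 1


def first_if_list(
    storage: Union[None, VkStorageType, Sequence[VkStorageType]]
) -> Optional[VkStorageType]:
    # Operators with multiple outputs may report a list/tuple of storage types,
    # one per output; callers that need a single value take the first entry.
    if isinstance(storage, (list, tuple)):
        return storage[0] if storage else None
    return storage


assert first_if_list(None) is None
assert first_if_list(VkStorageType.BUFFER) is VkStorageType.BUFFER
assert (
    first_if_list([VkStorageType.TEXTURE_3D, VkStorageType.BUFFER])
    is VkStorageType.TEXTURE_3D
)

The same rule is applied to memory layouts in propose_node_layout.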

backends/vulkan/op_registry.py

Lines changed: 26 additions & 0 deletions
@@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
     return features
 
 
+@update_features(
+    [
+        exir_ops.edge.aten.native_group_norm.default,
+    ]
+)
+def register_native_group_norm(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.CHANNELS},
+    )
+    features.handles_own_prepacking = True
+
+    features.optimal_storage = [
+        VkStorageType.TEXTURE_3D,
+        VkStorageType.BUFFER,
+        VkStorageType.BUFFER,
+    ]
+
+    features.optimal_layout = [
+        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
+        VkMemoryLayout.TENSOR_WIDTH_PACKED,
+        VkMemoryLayout.TENSOR_WIDTH_PACKED,
+    ]
+
+    return features
+
+
 # Ported ops that support their own prepacking.
 @update_features(
     [
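
For context, aten.native_group_norm produces three tensors (the normalized output plus the per-group mean and rstd), which is why optimal_storage and optimal_layout each carry three entries: the output stays a channels-packed texture while the statistics live in width-packed buffers. A small sketch of the assumed indexing convention, where entry_for_output is a hypothetical helper and the strings stand in for the real enum values:

from typing import Sequence, TypeVar, Union

T = TypeVar("T")


def entry_for_output(value: Union[T, Sequence[T]], output_idx: int) -> T:
    # Hypothetical helper: a single value applies to every output, while a
    # list/tuple carries one entry per output tensor.
    if isinstance(value, (list, tuple)):
        return value[output_idx]
    return value


# native_group_norm returns (out, mean, rstd); per the registration above the
# normalized output is a channels-packed texture and the statistics are
# width-packed buffers.
storage_per_output = ["TEXTURE_3D", "BUFFER", "BUFFER"]  # stand-ins for VkStorageType
assert entry_for_output(storage_per_output, 0) == "TEXTURE_3D"
assert entry_for_output(storage_per_output, 2) == "BUFFER"
assert entry_for_output("TEXTURE_3D", 1) == "TEXTURE_3D"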

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 66 additions & 0 deletions
@@ -1898,3 +1898,69 @@ def forward(self, x):
             dynamic_shapes=dynamic_shapes,
             test_inputs=test_inputs,
         )
+
+    def test_vulkan_backend_group_norm(self):
+        class ConvGroupNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Conv2d: 3 input channels -> 16 output channels
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=16,
+                    kernel_size=3,
+                    padding=1,
+                    bias=True,
+                )
+                # GroupNorm: 4 groups for 16 channels (16 % 4 == 0)
+                self.group_norm = torch.nn.GroupNorm(
+                    num_groups=4,
+                    num_channels=16,
+                    eps=1e-5,
+                    affine=True,
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.group_norm(x)
+                return x
+
+        # Create sample inputs: [batch, channels, height, width]
+        sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),)
+
+        # Test with static shapes first
+        self.lower_module_and_test_output(
+            ConvGroupNormModule(),
+            sample_inputs,
+        )
+
+    def test_vulkan_backend_group_norm_different_groups(self):
+        class GroupNormModule(torch.nn.Module):
+            def __init__(self, num_groups, num_channels):
+                super().__init__()
+                self.group_norm = torch.nn.GroupNorm(
+                    num_groups=num_groups,
+                    num_channels=num_channels,
+                    eps=1e-5,
+                    affine=True,
+                )
+
+            def forward(self, x):
+                return self.group_norm(x)
+
+        # Test different group configurations
+        test_configs = [
+            (2, 8),  # 2 groups, 8 channels
+            (4, 16),  # 4 groups, 16 channels
+            (8, 32),  # 8 groups, 32 channels
+        ]
+
+        for num_groups, num_channels in test_configs:
+            with self.subTest(num_groups=num_groups, num_channels=num_channels):
+                sample_inputs = (
+                    torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32),
+                )
+
+                self.lower_module_and_test_output(
+                    GroupNormModule(num_groups, num_channels),
+                    sample_inputs,
+                )
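
These tests rely on torch.nn.GroupNorm lowering to the native_group_norm op registered above. As a quick sanity check outside the test harness, one can export a standalone module and inspect the graph; this sketch is not part of the commit and assumes a PyTorch build with torch.export:

import torch


class GroupNormOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.group_norm = torch.nn.GroupNorm(num_groups=4, num_channels=16)

    def forward(self, x):
        return self.group_norm(x)


# Export and print the graph; it is expected to contain a call to
# aten.native_group_norm, whose three outputs (out, mean, rstd) match the
# three-entry storage/layout lists in the op registration above.
ep = torch.export.export(GroupNormOnly(), (torch.randn(1, 16, 8, 8),))
print(ep.graph)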

backends/vulkan/utils.py

Lines changed: 13 additions & 3 deletions
@@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
     if isinstance(spec, TensorSpec):
         setattr(spec, attr, value)
     elif isinstance(spec, (list, tuple)):
-        for s in spec:
-            assert isinstance(s, TensorSpec)
-            setattr(s, attr, value)
+        # Special case if value is a list/tuple of the same length as the
+        # collection of tensors in the node. In this case, treat the value list
+        # as a list of values to set individually for each tensor in the node
+        if isinstance(value, (list, tuple)) and len(spec) == len(value):
+            assert len(spec) == len(value)
+            for s, v in zip(spec, value):
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, v)
+        # Otherwise, set the attribute to value for all tensors in the list
+        else:
+            for s in spec:
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, value)
     else:
         raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
 
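
To make the new branch concrete, here is a self-contained sketch of the same dispatch logic using a stand-in spec class; FakeTensorSpec and the attribute names are illustrative rather than the real exir TensorSpec API. A list value whose length matches the number of output specs is applied element-wise, and anything else is broadcast to every spec:

from typing import Any, Sequence, Union


class FakeTensorSpec:
    """Stand-in for exir's TensorSpec; real specs also carry shape/dtype metadata."""


def set_spec_attr(
    spec: Union[FakeTensorSpec, Sequence[FakeTensorSpec]], attr: str, value: Any
) -> None:
    if isinstance(spec, FakeTensorSpec):
        setattr(spec, attr, value)
    elif isinstance(spec, (list, tuple)):
        if isinstance(value, (list, tuple)) and len(spec) == len(value):
            # Per-output values: assign element-wise.
            for s, v in zip(spec, value):
                setattr(s, attr, v)
        else:
            # Single value: broadcast to every output spec.
            for s in spec:
                setattr(s, attr, value)
    else:
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")


# Attribute names below are illustrative only.
specs = [FakeTensorSpec(), FakeTensorSpec(), FakeTensorSpec()]
set_spec_attr(specs, "vk_storage_type", ["TEXTURE_3D", "BUFFER", "BUFFER"])
set_spec_attr(specs, "vk_memory_layout", "TENSOR_WIDTH_PACKED")  # broadcast case
assert specs[0].vk_storage_type == "TEXTURE_3D"
assert all(s.vk_memory_layout == "TENSOR_WIDTH_PACKED" for s in specs)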

0 commit comments
