
Commit 167a3bc

pytorchbot and hinriksnaer authored and committed
[ET-VK] Allow specifying multiple storage types/memory layouts for an operator + register group norm operator (pytorch#11974)
## Changes

* Handle cases where an operator needs to specify a separate storage type / memory layout for each individual output.

## Motivation

Required for the group norm operator.

## Future Work

Currently, the `tag_memory_meta_pass` graph pass assumes that all tensors participating in a computation (aside from weights) will have the same storage type and memory layout. As more operators are added, there are more exceptions to this rule. The pass may need an update in the near future to make it possible to specify required storage types and memory layouts at a more granular level.

Differential Revision: [D77038781](https://our.internmc.facebook.com/intern/diff/D77038781/)
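To make the per-output mechanism concrete, here is a minimal sketch of the idea (illustrative only; `storage_for_output` is a hypothetical helper, not part of the ExecuTorch API): a preferred storage can be a single value that applies to every output, or a list with one entry per output tensor.

```python
from typing import List, Union

# Hypothetical helper (for illustration): resolve the storage type for a
# given output index when the registered preference may be scalar or a list.
def storage_for_output(optimal_storage: Union[str, List[str]], output_idx: int) -> str:
    if isinstance(optimal_storage, (list, tuple)):
        return optimal_storage[output_idx]
    return optimal_storage

# Group norm produces (out, mean, rstd); each output may request its own storage.
print(storage_for_output(["TEXTURE_3D", "BUFFER", "BUFFER"], 1))  # BUFFER
print(storage_for_output("TEXTURE_3D", 2))                        # TEXTURE_3D
```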
1 parent 8d12f03 commit 167a3bc

File tree

4 files changed: +130 −12 lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py
backends/vulkan/op_registry.py
backends/vulkan/test/test_vulkan_delegate.py
backends/vulkan/utils.py


backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 25 additions & 9 deletions
```diff
@@ -91,7 +91,7 @@ def __init__(
         self.default_layout: VkMemoryLayout = default_memory_layout
         self.texture_limits = texture_limits
 
-    def propose_node_storage(
+    def propose_node_storage(  # noqa: C901
         self,
         node: torch.fx.Node,
     ) -> Optional[VkStorageType]:
@@ -138,15 +138,23 @@ def propose_node_storage(
         for arg in node.args:
             if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 storage = utils.get_node_storage_type(arg)
+                # Some operators which return multiple output tensors may specify a
+                # different storage type for each output. In this case, the storage
+                # type for the first output is used.
+                if isinstance(storage, (list, tuple)):
+                    storage = storage[0]
                 if storage is not None and storage in valid_storage_types:
                     return storage
 
         # If no storage type has been resolved yet, assume the optimal storage type of
         # the first opinionated user. This search is recursive.
         for user in node.users:
-            optimal_storage = self.propose_node_storage(user)
-            if optimal_storage is not None:
-                return optimal_storage
+            storage = self.propose_node_storage(user)
+            # See above comment
+            if isinstance(storage, (list, tuple)):
+                storage = storage[0]
+            if storage is not None:
+                return storage
 
         if self.default_storage in valid_storage_types:
             return self.default_storage
@@ -179,15 +187,23 @@ def propose_node_layout(
         for arg in node.args:
             if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 layout = utils.get_node_memory_layout(arg)
+                # Some operators which return multiple output tensors may specify a
+                # different memory layout for each output. In this case, the memory
+                # layout for the first output is used.
+                if isinstance(layout, (list, tuple)):
+                    layout = layout[0]
                 if layout is not None and layout in valid_layouts:
                     return layout
 
-        # If no storage type has been resolved yet, assume the optimal storage type of
-        # the first opinionated user. This search is recursive.
+        # If no memory layout has been resolved yet, assume the optimal layout of the
+        # first opinionated user. This search is recursive.
         for user in node.users:
-            optimal_storage = self.propose_node_layout(user, storage)
-            if optimal_storage is not None:
-                return optimal_storage
+            layout = self.propose_node_layout(user, storage)
+            # See above comment
+            if isinstance(layout, (list, tuple)):
+                layout = layout[0]
+            if layout is not None:
+                return layout
 
         # As a last resort, return the default memory layout that should be used.
         if self.default_layout in valid_layouts:
```
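The fallback behavior added by this diff can be summarized with a small standalone sketch (plain strings stand in for `VkStorageType` values; the real pass operates on `torch.fx.Node`s):

```python
# Sketch of the new fallback: when a producer or user reports one storage
# type per output, the pass keys its decision off the first output.
def resolve_first(proposed):
    # `proposed` mirrors what utils.get_node_storage_type (or a recursive
    # propose_node_storage call) may now return: None, a single value,
    # or a list/tuple with one value per output.
    if isinstance(proposed, (list, tuple)):
        return proposed[0]
    return proposed

assert resolve_first(None) is None
assert resolve_first("BUFFER") == "BUFFER"
assert resolve_first(["TEXTURE_3D", "BUFFER", "BUFFER"]) == "TEXTURE_3D"
```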

backends/vulkan/op_registry.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
     return features
 
 
+@update_features(
+    [
+        exir_ops.edge.aten.native_group_norm.default,
+    ]
+)
+def register_native_group_norm(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.CHANNELS},
+    )
+    features.handles_own_prepacking = True
+
+    features.optimal_storage = [
+        VkStorageType.TEXTURE_3D,
+        VkStorageType.BUFFER,
+        VkStorageType.BUFFER,
+    ]
+
+    features.optimal_layout = [
+        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
+        VkMemoryLayout.TENSOR_WIDTH_PACKED,
+        VkMemoryLayout.TENSOR_WIDTH_PACKED,
+    ]
+
+    return features
+
+
 # Ported ops that support their own prepacking.
 @update_features(
     [
```
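For context on the three entries: `native_group_norm` returns the normalized output plus the per-group mean and rstd, so the registration pairs a texture-backed primary output with two buffer-backed auxiliary outputs. A quick eager-mode check of the operator's output arity (plain PyTorch, independent of the Vulkan backend):

```python
import torch

# native_group_norm(input, weight, bias, N, C, HxW, group, eps) returns
# (out, mean, rstd) -- three tensors, matching the three storage types and
# memory layouts listed in the registration above.
x = torch.randn(1, 16, 8, 8)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, torch.ones(16), torch.zeros(16), 1, 16, 8 * 8, 4, 1e-5
)
print(out.shape, mean.shape, rstd.shape)
# torch.Size([1, 16, 8, 8]) torch.Size([1, 4]) torch.Size([1, 4])
```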

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 66 additions & 0 deletions
```diff
@@ -1898,3 +1898,69 @@ def forward(self, x):
             dynamic_shapes=dynamic_shapes,
             test_inputs=test_inputs,
         )
+
+    def test_vulkan_backend_group_norm(self):
+        class ConvGroupNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Conv2d: 3 input channels -> 16 output channels
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=16,
+                    kernel_size=3,
+                    padding=1,
+                    bias=True,
+                )
+                # GroupNorm: 4 groups for 16 channels (16 % 4 == 0)
+                self.group_norm = torch.nn.GroupNorm(
+                    num_groups=4,
+                    num_channels=16,
+                    eps=1e-5,
+                    affine=True,
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.group_norm(x)
+                return x
+
+        # Create sample inputs: [batch, channels, height, width]
+        sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),)
+
+        # Test with static shapes first
+        self.lower_module_and_test_output(
+            ConvGroupNormModule(),
+            sample_inputs,
+        )
+
+    def test_vulkan_backend_group_norm_different_groups(self):
+        class GroupNormModule(torch.nn.Module):
+            def __init__(self, num_groups, num_channels):
+                super().__init__()
+                self.group_norm = torch.nn.GroupNorm(
+                    num_groups=num_groups,
+                    num_channels=num_channels,
+                    eps=1e-5,
+                    affine=True,
+                )
+
+            def forward(self, x):
+                return self.group_norm(x)
+
+        # Test different group configurations
+        test_configs = [
+            (2, 8),  # 2 groups, 8 channels
+            (4, 16),  # 4 groups, 16 channels
+            (8, 32),  # 8 groups, 32 channels
+        ]
+
+        for num_groups, num_channels in test_configs:
+            with self.subTest(num_groups=num_groups, num_channels=num_channels):
+                sample_inputs = (
+                    torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32),
+                )
+
+                self.lower_module_and_test_output(
+                    GroupNormModule(num_groups, num_channels),
+                    sample_inputs,
+                )
```
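As a reference for what these tests check numerically: GroupNorm normalizes each group of channels to zero mean and unit variance before applying the per-channel affine transform (identity at initialization, since weight starts at 1 and bias at 0). A quick eager-mode cross-check against a manual computation:

```python
import torch

x = torch.randn(2, 8, 4, 4)
gn = torch.nn.GroupNorm(num_groups=2, num_channels=8, eps=1e-5)

# Manual computation: flatten each group, normalize, reshape back.
grouped = x.reshape(2, 2, -1)  # (batch, groups, elems_per_group)
mean = grouped.mean(dim=-1, keepdim=True)
var = grouped.var(dim=-1, unbiased=False, keepdim=True)
manual = ((grouped - mean) / torch.sqrt(var + 1e-5)).reshape_as(x)

torch.testing.assert_close(gn(x), manual)
```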

backends/vulkan/utils.py

Lines changed: 13 additions & 3 deletions
```diff
@@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
     if isinstance(spec, TensorSpec):
         setattr(spec, attr, value)
     elif isinstance(spec, (list, tuple)):
-        for s in spec:
-            assert isinstance(s, TensorSpec)
-            setattr(s, attr, value)
+        # Special case: if value is a list/tuple of the same length as the
+        # collection of tensors in the node, treat the value list as a list
+        # of values to set individually for each tensor in the node
+        if isinstance(value, (list, tuple)) and len(spec) == len(value):
+            assert len(spec) == len(value)
+            for s, v in zip(spec, value):
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, v)
+        # Otherwise, set the attribute to value for all tensors in the list
+        else:
+            for s in spec:
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, value)
     else:
         raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
```
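The two branches above can be exercised in isolation with stand-in objects (a toy `FakeSpec` replaces `TensorSpec`; this is a sketch, not the ExecuTorch API):

```python
# Toy illustration: a list value of matching length is distributed
# element-wise; any other value fans out to every spec.
class FakeSpec:  # stand-in for TensorSpec
    pass

def set_specs(specs, attr, value):
    if isinstance(value, (list, tuple)) and len(specs) == len(value):
        for s, v in zip(specs, value):
            setattr(s, attr, v)  # one value per output tensor
    else:
        for s in specs:
            setattr(s, attr, value)  # same value for all output tensors

specs = [FakeSpec(), FakeSpec(), FakeSpec()]
set_specs(specs, "vk_storage_type", ["TEXTURE_3D", "BUFFER", "BUFFER"])
print([s.vk_storage_type for s in specs])  # ['TEXTURE_3D', 'BUFFER', 'BUFFER']
```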
