34 changes: 25 additions & 9 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -91,7 +91,7 @@ def __init__(
self.default_layout: VkMemoryLayout = default_memory_layout
self.texture_limits = texture_limits

def propose_node_storage(
def propose_node_storage( # noqa: C901
self,
node: torch.fx.Node,
) -> Optional[VkStorageType]:
@@ -138,15 +138,23 @@ def propose_node_storage(
for arg in node.args:
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
storage = utils.get_node_storage_type(arg)
# Some operators which return multiple output tensors may specify a
# different storage type for each output. In this case, the storage type
# for the first output is used.
if isinstance(storage, (list, tuple)):
storage = storage[0]
if storage is not None and storage in valid_storage_types:
return storage

# If no storage type has been resolved yet, assume the optimal storage type of
# the first opinionated user. This search is recursive.
for user in node.users:
optimal_storage = self.propose_node_storage(user)
if optimal_storage is not None:
return optimal_storage
storage = self.propose_node_storage(user)
# See above
if isinstance(storage, (list, tuple)):
storage = storage[0]
if storage is not None:
return storage

if self.default_storage in valid_storage_types:
return self.default_storage
@@ -179,15 +187,23 @@ def propose_node_layout(
for arg in node.args:
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
layout = utils.get_node_memory_layout(arg)
# Some operators which return multiple output tensors may specify a
# different memory layout for each output. In this case, the memory
# layout for the first output is used.
if isinstance(layout, (list, tuple)):
layout = layout[0]
if layout is not None and layout in valid_layouts:
return layout

# If no storage type has been resolved yet, assume the optimal storage type of
# the first opinionated user. This search is recursive.
# If no memory layout has been resolved yet, assume the optimal layout of the
# first opinionated user. This search is recursive.
for user in node.users:
optimal_storage = self.propose_node_layout(user, storage)
if optimal_storage is not None:
return optimal_storage
layout = self.propose_node_layout(user, storage)
# See above comment
if isinstance(layout, (list, tuple)):
layout = layout[0]
if layout is not None:
return layout

# As a last resort, return the default memory layout that should be used.
if self.default_layout in valid_layouts:
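The changes above repeatedly apply one rule: when a producer argument or a consumer of the node is itself a multi-output node, its resolved storage type or memory layout comes back as a list, and the setting for the first output is taken. A minimal standalone sketch of that fallback is below; the string values and the helper name are illustrative only and are not part of the backend's API.

```python
# Standalone sketch of the "first output wins" fallback used above.
# Plain strings stand in for VkStorageType / VkMemoryLayout values.
from typing import Optional, Sequence, Union

Setting = str


def first_output_setting(
    resolved: Union[Setting, Sequence[Setting], None],
) -> Optional[Setting]:
    # Multi-output nodes resolve to a list/tuple with one entry per output;
    # collapse that to the first output's setting.
    if isinstance(resolved, (list, tuple)):
        return resolved[0] if resolved else None
    return resolved


assert first_output_setting(["TEXTURE_3D", "BUFFER", "BUFFER"]) == "TEXTURE_3D"
assert first_output_setting("BUFFER") == "BUFFER"
assert first_output_setting(None) is None
```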
26 changes: 26 additions & 0 deletions backends/vulkan/op_registry.py
@@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
return features


@update_features(
[
exir_ops.edge.aten.native_group_norm.default,
]
)
def register_native_group_norm(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.CHANNELS},
)
features.handles_own_prepacking = True

features.optimal_storage = [
VkStorageType.TEXTURE_3D,
VkStorageType.BUFFER,
VkStorageType.BUFFER,
]

features.optimal_layout = [
VkMemoryLayout.TENSOR_CHANNELS_PACKED,
VkMemoryLayout.TENSOR_WIDTH_PACKED,
VkMemoryLayout.TENSOR_WIDTH_PACKED,
]

return features


# Ported ops that support their own prepacking.
@update_features(
[
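The three-element optimal_storage and optimal_layout lists above line up with the three outputs of native_group_norm: the normalized output stays a channels-packed 3D texture, while the saved mean and rstd are kept as width-packed buffers. The short sketch below shows the op's output arity; the concrete shapes assume ATen's convention that mean and rstd are stored per (batch, group).

```python
# Sketch of why the lists above have three entries: aten.native_group_norm
# returns (out, mean, rstd).
import torch

N, C, H, W, groups, eps = 1, 16, 8, 8, 4, 1e-5
x = torch.randn(N, C, H, W)
weight, bias = torch.ones(C), torch.zeros(C)

out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, N, C, H * W, groups, eps
)
print(out.shape, mean.shape, rstd.shape)  # (1, 16, 8, 8), (1, 4), (1, 4)
```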
66 changes: 66 additions & 0 deletions backends/vulkan/test/test_vulkan_delegate.py
@@ -1898,3 +1898,69 @@ def forward(self, x):
dynamic_shapes=dynamic_shapes,
test_inputs=test_inputs,
)

def test_vulkan_backend_group_norm(self):
class ConvGroupNormModule(torch.nn.Module):
def __init__(self):
super().__init__()
# Conv2d: 3 input channels -> 16 output channels
self.conv = torch.nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=3,
padding=1,
bias=True,
)
# GroupNorm: 4 groups for 16 channels (16 % 4 == 0)
self.group_norm = torch.nn.GroupNorm(
num_groups=4,
num_channels=16,
eps=1e-5,
affine=True,
)

def forward(self, x):
x = self.conv(x)
x = self.group_norm(x)
return x

# Create sample inputs: [batch, channels, height, width]
sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),)

# Test with static shapes
self.lower_module_and_test_output(
ConvGroupNormModule(),
sample_inputs,
)

def test_vulkan_backend_group_norm_different_groups(self):
class GroupNormModule(torch.nn.Module):
def __init__(self, num_groups, num_channels):
super().__init__()
self.group_norm = torch.nn.GroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=1e-5,
affine=True,
)

def forward(self, x):
return self.group_norm(x)

# Test different group configurations
test_configs = [
(2, 8), # 2 groups, 8 channels
(4, 16), # 4 groups, 16 channels
(8, 32), # 8 groups, 32 channels
]

for num_groups, num_channels in test_configs:
with self.subTest(num_groups=num_groups, num_channels=num_channels):
sample_inputs = (
torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32),
)

self.lower_module_and_test_output(
GroupNormModule(num_groups, num_channels),
sample_inputs,
)
16 changes: 13 additions & 3 deletions backends/vulkan/utils.py
@@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
if isinstance(spec, TensorSpec):
setattr(spec, attr, value)
elif isinstance(spec, (list, tuple)):
for s in spec:
assert isinstance(s, TensorSpec)
setattr(s, attr, value)
# Special case if value is a list/tuple of the same length as the
# collection of tensors in the node. In this case, treat the value list
# as a list of values to set individually for each tensor in the node.
if isinstance(value, (list, tuple)) and len(spec) == len(value):
for s, v in zip(spec, value):
assert isinstance(s, TensorSpec)
setattr(s, attr, v)
# Otherwise, set the attribute to value for all tensors in the list
else:
for s in spec:
assert isinstance(s, TensorSpec)
setattr(s, attr, value)
else:
raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")

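With the utils.py change above, set_node_spec_attr can accept either a single value (broadcast to every output spec) or a list with one value per output, which is what lets the per-output storage and layout lists from op_registry.py land on the right TensorSpec. A standalone sketch of the two branches is below; SimpleNamespace stands in for TensorSpec and the attribute names are illustrative only.

```python
# Standalone sketch of the two branches added to set_node_spec_attr above.
from types import SimpleNamespace


def set_spec_attr(spec, attr, value):
    if isinstance(spec, (list, tuple)):
        if isinstance(value, (list, tuple)) and len(spec) == len(value):
            # Per-output values: one entry of `value` per tensor spec.
            for s, v in zip(spec, value):
                setattr(s, attr, v)
        else:
            # Scalar value: broadcast to every tensor spec.
            for s in spec:
                setattr(s, attr, value)
    else:
        setattr(spec, attr, value)


specs = [SimpleNamespace(), SimpleNamespace(), SimpleNamespace()]
set_spec_attr(specs, "vk_storage_type", ["TEXTURE_3D", "BUFFER", "BUFFER"])
set_spec_attr(specs, "vk_memory_layout", "TENSOR_WIDTH_PACKED")
print([s.vk_storage_type for s in specs])   # ['TEXTURE_3D', 'BUFFER', 'BUFFER']
print([s.vk_memory_layout for s in specs])  # ['TENSOR_WIDTH_PACKED'] * 3
```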