diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 836a0c6ef7d..0bd8dae0b66 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -91,7 +91,7 @@ def __init__(
         self.default_layout: VkMemoryLayout = default_memory_layout
         self.texture_limits = texture_limits
 
-    def propose_node_storage(
+    def propose_node_storage(  # noqa: C901
        self,
        node: torch.fx.Node,
    ) -> Optional[VkStorageType]:
@@ -138,15 +138,23 @@ def propose_node_storage(
         for arg in node.args:
             if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 storage = utils.get_node_storage_type(arg)
+                # Some operators which return multiple output tensors may specify a
+                # different storage type for each output. In this case, the storage type
+                # for the first output is used.
+                if isinstance(storage, (list, tuple)):
+                    storage = storage[0]
                 if storage is not None and storage in valid_storage_types:
                     return storage
 
         # If no storage type has been resolved yet, assume the optimal storage type of
         # the first opinionated user. This search is recursive.
         for user in node.users:
-            optimal_storage = self.propose_node_storage(user)
-            if optimal_storage is not None:
-                return optimal_storage
+            storage = self.propose_node_storage(user)
+            # See above
+            if isinstance(storage, (list, tuple)):
+                storage = storage[0]
+            if storage is not None:
+                return storage
 
         if self.default_storage in valid_storage_types:
             return self.default_storage
@@ -179,15 +187,23 @@ def propose_node_layout(
         for arg in node.args:
             if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
                 layout = utils.get_node_memory_layout(arg)
+                # Some operators which return multiple output tensors may specify a
+                # different memory layout for each output. In this case, the memory
+                # layout for the first output is used.
+                if isinstance(layout, (list, tuple)):
+                    layout = layout[0]
                 if layout is not None and layout in valid_layouts:
                     return layout
 
-        # If no storage type has been resolved yet, assume the optimal storage type of
-        # the first opinionated user. This search is recursive.
+        # If no memory layout has been resolved yet, assume the optimal layout of the
+        # first opinionated user. This search is recursive.
         for user in node.users:
-            optimal_storage = self.propose_node_layout(user, storage)
-            if optimal_storage is not None:
-                return optimal_storage
+            layout = self.propose_node_layout(user, storage)
+            # See above comment
+            if isinstance(layout, (list, tuple)):
+                layout = layout[0]
+            if layout is not None:
+                return layout
 
         # As a last resort, return the default storage type that should be used.
         if self.default_layout in valid_layouts:
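For context, both hunks above apply the same normalization: a node's storage type or memory layout annotation may now be either a single value or a per-output list, and the proposal logic collapses a list to its first entry. A minimal standalone sketch of that pattern, assuming nothing beyond the patch itself (first_or_none is a hypothetical helper name used only for illustration and does not exist in the pass):

from typing import Optional, Sequence, TypeVar, Union

T = TypeVar("T")


def first_or_none(value: Union[None, T, Sequence[T]]) -> Optional[T]:
    # Multi-output operators may carry one setting per output tensor; the
    # proposal logic only needs a single representative, so the first entry
    # is used. Like the patch, this assumes a non-empty list.
    if isinstance(value, (list, tuple)):
        return value[0]
    return value


assert first_or_none(None) is None
assert first_or_none("BUFFER") == "BUFFER"
assert first_or_none(["TEXTURE_3D", "BUFFER", "BUFFER"]) == "TEXTURE_3D"
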
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 9333f34430e..0258aceb82b 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -655,6 +655,32 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
     return features
 
 
+@update_features(
+    [
+        exir_ops.edge.aten.native_group_norm.default,
+    ]
+)
+def register_native_group_norm(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.CHANNELS},
+    )
+    features.handles_own_prepacking = True
+
+    features.optimal_storage = [
+        VkStorageType.TEXTURE_3D,
+        VkStorageType.BUFFER,
+        VkStorageType.BUFFER,
+    ]
+
+    features.optimal_layout = [
+        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
+        VkMemoryLayout.TENSOR_WIDTH_PACKED,
+        VkMemoryLayout.TENSOR_WIDTH_PACKED,
+    ]
+
+    return features
+
+
 # Ported ops that support their own prepacking.
 @update_features(
     [
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 0096834f3c6..04adf183e55 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -1898,3 +1898,69 @@ def forward(self, x):
             dynamic_shapes=dynamic_shapes,
             test_inputs=test_inputs,
         )
+
+    def test_vulkan_backend_group_norm(self):
+        class ConvGroupNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Conv2d: 3 input channels -> 16 output channels
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=16,
+                    kernel_size=3,
+                    padding=1,
+                    bias=True,
+                )
+                # GroupNorm: 4 groups for 16 channels (16 % 4 == 0)
+                self.group_norm = torch.nn.GroupNorm(
+                    num_groups=4,
+                    num_channels=16,
+                    eps=1e-5,
+                    affine=True,
+                )
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.group_norm(x)
+                return x
+
+        # Create sample inputs: [batch, channels, height, width]
+        sample_inputs = (torch.randn(size=(1, 3, 32, 32), dtype=torch.float32),)
+
+        # Test with static shapes first
+        self.lower_module_and_test_output(
+            ConvGroupNormModule(),
+            sample_inputs,
+        )
+
+    def test_vulkan_backend_group_norm_different_groups(self):
+        class GroupNormModule(torch.nn.Module):
+            def __init__(self, num_groups, num_channels):
+                super().__init__()
+                self.group_norm = torch.nn.GroupNorm(
+                    num_groups=num_groups,
+                    num_channels=num_channels,
+                    eps=1e-5,
+                    affine=True,
+                )
+
+            def forward(self, x):
+                return self.group_norm(x)
+
+        # Test different group configurations
+        test_configs = [
+            (2, 8),  # 2 groups, 8 channels
+            (4, 16),  # 4 groups, 16 channels
+            (8, 32),  # 8 groups, 32 channels
+        ]
+
+        for num_groups, num_channels in test_configs:
+            with self.subTest(num_groups=num_groups, num_channels=num_channels):
+                sample_inputs = (
+                    torch.randn(size=(2, num_channels, 16, 16), dtype=torch.float32),
+                )
+
+                self.lower_module_and_test_output(
+                    GroupNormModule(num_groups, num_channels),
+                    sample_inputs,
+                )
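The three-element optimal_storage and optimal_layout lists registered above line up with the three tensors that the ATen op returns: native_group_norm produces the normalized output plus per-group mean and rstd statistics. A quick eager-mode sketch of that output structure, independent of the Vulkan delegate (the storage comments restate the registration above, not anything the snippet verifies):

import torch

x = torch.randn(2, 16, 8, 8)
weight = torch.ones(16)
bias = torch.zeros(16)

# Signature: native_group_norm(input, weight, bias, N, C, HxW, group, eps)
out, mean, rstd = torch.ops.aten.native_group_norm(
    x, weight, bias, 2, 16, 8 * 8, 4, 1e-5
)
print(out.shape)   # torch.Size([2, 16, 8, 8]) -> TEXTURE_3D in the registration
print(mean.shape)  # torch.Size([2, 4]) per-group stats -> BUFFER
print(rstd.shape)  # torch.Size([2, 4]) per-group stats -> BUFFER
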
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 642f7c5f495..5d57ce1e7be 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -264,9 +264,19 @@ def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
     if isinstance(spec, TensorSpec):
         setattr(spec, attr, value)
     elif isinstance(spec, (list, tuple)):
-        for s in spec:
-            assert isinstance(s, TensorSpec)
-            setattr(s, attr, value)
+        # Special case if value is a list/tuple of the same length as the
+        # collection of tensors in the node. In this case, treat the value list
+        # as a list of values to set individually for each tensor in the node.
+        if isinstance(value, (list, tuple)) and len(spec) == len(value):
+            assert len(spec) == len(value)
+            for s, v in zip(spec, value):
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, v)
+        # Otherwise, set the attribute to value for all tensors in the list
+        else:
+            for s in spec:
+                assert isinstance(s, TensorSpec)
+                setattr(s, attr, value)
     else:
         raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")
 
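The set_node_spec_attr change is what lets the per-output lists from op_registry.py actually land on each output's TensorSpec. Its branching can be sketched outside the FX machinery; _fan_out_attr and the SimpleNamespace stand-ins below are illustrative only and not part of the patch:

from types import SimpleNamespace


def _fan_out_attr(specs, attr, value):
    # Mirrors the updated set_node_spec_attr branch: a list/tuple value whose
    # length matches the spec collection is assigned element-wise; any other
    # value is broadcast to every spec.
    if isinstance(value, (list, tuple)) and len(specs) == len(value):
        for s, v in zip(specs, value):
            setattr(s, attr, v)
    else:
        for s in specs:
            setattr(s, attr, value)


specs = [SimpleNamespace(), SimpleNamespace(), SimpleNamespace()]
_fan_out_attr(specs, "vk_storage_type", ["TEXTURE_3D", "BUFFER", "BUFFER"])
print([s.vk_storage_type for s in specs])  # one value per output spec

_fan_out_attr(specs, "vk_memory_layout", "TENSOR_WIDTH_PACKED")
print([s.vk_memory_layout for s in specs])  # same value broadcast to all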