45 changes: 42 additions & 3 deletions backends/vulkan/op_registry.py
@@ -538,8 +538,6 @@ def register_rotary_emb_op(features: OpFeatures):
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.permute.default,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.view_copy.default,
]
)
@@ -551,6 +549,48 @@ def register_view_ops(features: OpFeatures):
return features


# The cat operator is a fully featured transfer operator (i.e. it copies data from the
# input tensors to the output tensor), with a memory layout agnostic implementation
# for both texture and buffer storage types, but it requires an additional node check.
@update_features(exir_ops.edge.aten.cat.default)
def register_cat_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims=all_packed_dims,
)
features.buffer_impl = True
features.resize_fn = True

    def check_cat_node(node: torch.fx.Node) -> bool:
        # Only cat nodes with at most 3 input tensors are currently supported.
        inputs = node.args[0]
        if isinstance(inputs, (list, tuple)) and len(inputs) <= 3:
            return True

        return False

features.check_node_fn = check_cat_node

return features

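A minimal sketch (an editor's illustration, not part of this diff) of the predicate that check_cat_node applies. It uses plain torch.fx tracing rather than an Edge dialect export for brevity; the check assumes that node.args[0] of a cat node is the list of input tensors, so a concatenation of up to three tensors passes while a larger one is rejected.

import torch
from torch.fx import symbolic_trace


class CatThree(torch.nn.Module):
    def forward(self, a, b, c):
        return torch.cat([a, b, c], dim=0)


gm = symbolic_trace(CatThree())
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.cat:
        inputs = node.args[0]
        # Same predicate as check_cat_node: accept at most three input tensors.
        print(isinstance(inputs, (list, tuple)) and len(inputs) <= 3)  # prints True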

# Fully featured transfer operators (i.e. operators that copy data from the input
# tensor(s) to the output tensor(s)), which have memory layout agnostic implementations
# for both texture and buffer storage types.
@update_features(
[
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.slice_copy.Tensor,
]
)
def register_transfer_ops(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims=all_packed_dims,
)
features.buffer_impl = True
features.resize_fn = True

return features

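For reference, a small eager-mode sketch (an editor's illustration, not part of this diff) of the two edge ops registered above. Each one simply copies a selected or sliced region of the input into a freshly allocated output tensor, matching the "transfer operator" description in the comment.

import torch

x = torch.arange(24.0).reshape(2, 3, 4)

# Eager equivalents of the registered edge dialect ops.
selected = torch.select_copy(x, dim=1, index=0)       # aten.select_copy.int
sliced = torch.slice_copy(x, dim=2, start=1, end=3)   # aten.slice_copy.Tensor

print(selected.shape)  # torch.Size([2, 4])
print(sliced.shape)    # torch.Size([2, 3, 2])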

# Ops ported from PyTorch Vulkan backend. These ops commonly support channels
# packed tensors only and do not have a resize function.
@update_features(
@@ -588,7 +628,6 @@ def register_ported_op(features: OpFeatures):
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.unsqueeze_copy.default,
# Tensor combination
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split.Tensor,