Skip to content

Commit 81d2e4a

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] Contiguous buffer implementation for embedding"
Title says it all! Also moved around some structs throughout the `glslh` files for better code organization. Differential Revision: [D84716456](https://our.internmc.facebook.com/intern/diff/D84716456/) [ghstack-poisoned]
1 parent d69c05c commit 81d2e4a

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

backends/vulkan/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,32 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool:
259259
return False
260260

261261

262+
def ndim_of(node: Any) -> Optional[int]:
263+
"""
264+
Returns the number of dimensions of the tensor produced by the given node
265+
"""
266+
if not is_single_tensor_node(node):
267+
return None
268+
269+
return node.meta["val"].ndim
270+
271+
272+
def is_unsqueezed_vector(node: torch.fx.Node) -> bool:
273+
"""
274+
Returns True if the node's tensor has all dimensions equal to 1 except for the last dimension.
275+
"""
276+
if not is_single_tensor_node(node):
277+
return False
278+
279+
tensor = node.meta["val"]
280+
assert isinstance(tensor, FakeTensor)
281+
282+
if len(tensor.shape) < 1:
283+
return False
284+
# All dims except last are 1, last can be any size
285+
return all(dim == 1 for dim in tensor.shape[:-1])
286+
287+
262288
def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
263289
"""
264290
Returns true if the operator used to compute the given node contains a bool tensor
@@ -267,6 +293,7 @@ def op_contains_bool_tensor(node: torch.fx.Node) -> bool:
267293
return True
268294

269295
for arg_node in node.args:
296+
# pyre-ignore[6]
270297
if is_tensor_node(arg_node) and tensor_node_is_bool(arg_node):
271298
return True
272299

@@ -582,6 +609,16 @@ def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
582609
self.valid_texture_layouts & other.valid_texture_layouts,
583610
)
584611

612+
def make_union(self, other: "TensorRepSet") -> "TensorRepSet":
613+
"""
614+
Merge this TensorRepSet with another TensorRepSet, returning a new TensorRepSet
615+
with the union of the two.
616+
"""
617+
return TensorRepSet(
618+
self.valid_buffer_layouts | other.valid_buffer_layouts,
619+
self.valid_texture_layouts | other.valid_texture_layouts,
620+
)
621+
585622
def is_compatible(self, storage: TensorRepr) -> bool:
586623
"""
587624
Check if this TensorRepr is compatible with the given TensorRepSet.
@@ -1240,6 +1277,26 @@ def is_in_8bit_range(tensor: torch.Tensor) -> bool:
12401277
##
12411278

12421279

1280+
def normalize_dims(dims: Union[int, List[int]], ndim: int) -> Union[int, List[int]]:
1281+
"""
1282+
Normalize dimension indices to be non-negative and within [0, ndim).
1283+
Accepts a single int or a list of ints.
1284+
"""
1285+
if isinstance(dims, int):
1286+
if dims < 0:
1287+
dims += ndim
1288+
1289+
return dims
1290+
1291+
normalized = []
1292+
for d in dims:
1293+
if d < 0:
1294+
d += ndim
1295+
normalized.append(d)
1296+
1297+
return normalized
1298+
1299+
12431300
def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
12441301
# Handle negative indices for nchw_dim
12451302
if nchw_dim < 0:

0 commit comments

Comments
 (0)