40 changes: 11 additions & 29 deletions sharktank/sharktank/ops/default_impls.py
@@ -456,35 +456,17 @@ def gather_default(
return torch.gather(unbox_tensor(input), dim, unbox_tensor(index))


@extract_slice.override(AllOfType(Tensor, PrimitiveTensor))
def extract_slice_default(tensor, key):
return unbox_tensor(tensor)[key]


@extract_slice.override(QuantizedTensor)
def extract_slice_QuantizedTensor(tensor: QuantizedTensor, key: slice):
unpacked = tensor.unpack()
if isinstance(unpacked, BlockScaledI4Layout):
mul = 2
new_d = unpacked._d[key]
new_qs = unpacked._qs[key]
if unpacked.m is not None:
new_m = unpacked.m[key]
dims = new_qs.shape
dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
return PlanarQuantizedTensor(shape=dims, layout=layout)
elif isinstance(unpacked, TensorScaledLayout):
Comment from the PR author:
It might look like this is deleted, but there's another copy in quantized_impls. I do not know why.

Comment from the PR author:
Yet another reason to refactor all the variants for a single op into one place.
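For illustration only, a minimal standalone sketch of that idea: one dispatcher per op that unpacks the wrapped tensor and routes to per-layout implementations kept in a single registry. The names here (`register`, `LAYOUT_IMPLS`, `DummyLayout`, `DummyQuantizedTensor`) are hypothetical stand-ins, not the sharktank registry API:

```python
# Hypothetical sketch of "all variants of one op in one place".
# Not the sharktank API; names are illustrative only.
from dataclasses import dataclass
from typing import Callable

# Single registry: layout type -> extract_slice implementation.
LAYOUT_IMPLS: dict[type, Callable] = {}

def register(layout_type: type):
    """Register an extract_slice implementation for one layout type."""
    def wrap(fn: Callable) -> Callable:
        LAYOUT_IMPLS[layout_type] = fn
        return fn
    return wrap

@dataclass
class DummyLayout:
    data: list

@dataclass
class DummyQuantizedTensor:
    layout: DummyLayout

    def unpack(self) -> DummyLayout:
        return self.layout

def extract_slice(tensor: DummyQuantizedTensor, key: slice):
    """Single entry point: unpack, then dispatch on the layout type."""
    layout = tensor.unpack()
    impl = LAYOUT_IMPLS.get(type(layout))
    if impl is None:
        return NotImplemented
    return impl(layout, key)

@register(DummyLayout)
def extract_slice_dummy(layout: DummyLayout, key: slice) -> DummyLayout:
    # Per-layout implementation lives next to the other variants.
    return DummyLayout(layout.data[key])

# Usage:
t = DummyQuantizedTensor(DummyLayout([1, 2, 3, 4]))
print(extract_slice(t, slice(1, 3)).data)  # -> [2, 3]
```

The new `extract_slice_quantized_dispatcher` and per-layout overrides in quantized_impls.py below follow the same shape, with `extract_slice.override(...)` doing the routing.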

d = unpacked._d
qs = unpacked._qs[key]
if unpacked._m.dim() == 0:
m = unpacked._m
else:
m = unpacked._m[key]
shape = qs.shape
layout = TensorScaledLayout(shape=shape, d=d, qs=qs, m=m)
return PlanarQuantizedTensor(shape=shape, layout=layout)
return NotImplemented
@extract_slice.override(BlockScaledI4Layout)
def extract_slice_BlockScaledI4Layout(layout: BlockScaledI4Layout, key: slice):
mul = 2
new_d = layout._d[key]
new_qs = layout._qs[key]
if layout.m is not None:
new_m = layout.m[key]
dims = new_qs.shape
dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
new_layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
return PlanarQuantizedTensor(shape=dims, layout=new_layout)


@gemm.override(AllOfType(Tensor, InferenceTensor))
132 changes: 35 additions & 97 deletions sharktank/sharktank/ops/quantized_impls.py
@@ -5,15 +5,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import functools
import math
import inspect
import torch
import warnings

from collections.abc import Sequence
from copy import deepcopy
from typing import Any, Callable
from typing import Callable
from torch import Tensor
from ._registry import *
from sharktank.types import (
Expand Down Expand Up @@ -53,78 +51,6 @@
import iree.turbine.ops.iree


def quantized_tensor_layout_of_type(
*layout_types: tuple[QuantizedLayout | None],
**kw_layout_types: dict[str, QuantizedLayout | None],
) -> Callable[..., Any]:
"""Decorator that check that the arguments have the expected QuantizedLayout.

If the arguments have the expected layout call the function. If not, return NotImplemented.

E.g.
```
@my_fn.override(QuantizedTensor)
@quantized_tensor_layout_of_type(a=BlockScaledFp4Layout, b=SuperBlockOffsetScaled_4_6_Layout)
def my_fn_impl(a: QuantizedTensor, b: QuantizedTensor):
...
```

"""

def decorator(f: Callable[..., Any]):
signature = inspect.signature(f)

@functools.wraps(f)
def wrapper(*args, **kwargs):
# torch.export doesn't play nicely with inspect
if torch._dynamo.is_compiling():
return f(*args, **kwargs)

bound_arguments = signature.bind(*args, **kwargs)
bound_layout_types = signature.bind_partial(
*layout_types, **kw_layout_types
)
for k, layout_type in bound_layout_types.arguments.items():
if layout_type is None:
continue
if signature.parameters[k].kind == inspect.Parameter.VAR_POSITIONAL:
if any(
not isinstance(arg.to_planar().layout, l_type)
for l_type, arg in zip(
layout_type, bound_arguments.arguments[k]
)
):
return NotImplemented
if signature.parameters[k].kind == inspect.Parameter.VAR_KEYWORD:
if any(
not isinstance(
bound_arguments.arguments[k][name].to_planar().layout,
l_type,
)
for name, l_type in layout_type.items()
):
return NotImplemented
if not isinstance(
bound_arguments.arguments[k].to_planar().layout, layout_type
):
return NotImplemented

# All tensors have the expected layout, we can make the call.
return f(*args, **kwargs)

wrapper._layout_types = {}
if layout_types:
param_names = list(signature.parameters.keys())
wrapper._layout_types.update(
dict(zip(param_names[: len(layout_types)], layout_types))
)
if kw_layout_types:
wrapper._layout_types.update(kw_layout_types)
return wrapper

return decorator


def verify_quantized_shape(actual: tuple[int, ...], expected: tuple[int, ...]):
assert iterables_equal(
actual, expected
@@ -417,15 +343,20 @@ def cat_BlockScaledFp4Layout(tensors: Sequence[PlanarQuantizedTensor], dim: int)


@extract_slice.override(PlanarQuantizedTensor)
@quantized_tensor_layout_of_type(tensor=BlockScaledFp4Layout)
def extract_slice_BlockScaledFp4Layout(tensor: PlanarQuantizedTensor, key: Slice):
layout: BlockScaledFp4Layout = tensor.layout
slice_ = canonicalize_slice_descriptor(squeeze_slice(key), tensor.shape)
def extract_slice_quantized_dispatcher(
tensor: PlanarQuantizedTensor, key: Slice
) -> PlanarQuantizedTensor:
return extract_slice(tensor.unpack(), key)


@extract_slice.override(BlockScaledFp4Layout)
def extract_slice_BlockScaledFp4Layout(layout: BlockScaledFp4Layout, key: Slice):
slice_ = canonicalize_slice_descriptor(squeeze_slice(key), layout.shape)
assert all(
isinstance(s, slice) for s in slice_
), "Slicing with integers like tensor[1, 2, [3, 4]] is not supported. Only ranges are supported."
block_shape = tuple(
tensor.shape[i] // layout.d.shape[i] for i in range(len(tensor.shape))
layout.shape[i] // layout.d.shape[i] for i in range(len(layout.shape))
)
assert (
math.prod(block_shape) == layout.block_size
@@ -477,38 +408,45 @@ def extract_slice_BlockScaledFp4Layout(tensor: PlanarQuantizedTensor, key: Slice
)


@extract_slice.override(PlanarQuantizedTensor)
@quantized_tensor_layout_of_type(tensor=TensorScaledLayout)
@extract_slice.override(TensorScaledLayout)
def extract_slice_TensorScaledLayout(
tensor: PlanarQuantizedTensor, key: Slice
layout: TensorScaledLayout, key: Slice
) -> PlanarQuantizedTensor:
planes = dict(tensor.layout.planes)
planes = dict(layout.planes)
planes["qs"] = extract_slice(planes["qs"], key)
metadata = dict(tensor.layout.metadata)
metadata["shape"] = tensor.shape
metadata = dict(layout.metadata)
metadata["shape"] = layout.shape
return PlanarQuantizedTensor(
shape=tensor.shape,
layout=type(tensor.layout).create(
shape=tensor.layout.shape, metadata=metadata, planes=planes
shape=layout.shape,
layout=type(layout).create(
shape=layout.shape, metadata=metadata, planes=planes
),
)


@split.override(QuantizedTensor)
@quantized_tensor_layout_of_type(tensor=BlockScaledFp4Layout)
@split.override(PlanarQuantizedTensor)
def split_quantized_dispatcher(
tensor: PlanarQuantizedTensor,
split_size_or_sections: int | list[int],
dim: int = 0,
) -> tuple[PlanarQuantizedTensor, ...]:
return split(tensor.unpack(), split_size_or_sections, dim)


@split.override(BlockScaledFp4Layout)
def split_BlockScaledFp4Layout(
tensor: QuantizedTensor,
layout: BlockScaledFp4Layout,
split_size_or_sections: int | list[int],
dim: int = 0,
) -> tuple[QuantizedTensor, ...]:
dim = normalize_negative_dim(tensor, dim)
dim_size = tensor.shape[dim]
) -> tuple[PlanarQuantizedTensor, ...]:
dim = normalize_negative_dim(layout.shape, dim)
dim_size = layout.shape[dim]
if isinstance(split_size_or_sections, int):
sections = [split_size_or_sections] * (dim_size // split_size_or_sections)
reminder = dim_size % split_size_or_sections
if reminder != 0:
sections.append(reminder)
return split_BlockScaledFp4Layout(tensor, sections, dim)
return split_BlockScaledFp4Layout(layout, sections, dim)

assert len(split_size_or_sections) > 0
parts_range = [(0, split_size_or_sections[0])]
@@ -521,7 +459,7 @@ def split_BlockScaledFp4Layout(
slice_ = tuple(
slice(begin, end) if i == dim else slice(None) for i in range(dim + 1)
)
res.append(tensor[slice_])
res.append(extract_slice(layout, slice_))
return tuple(res)

