From f43525f31256f0cef3393283e3eb13d802450dff Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Mon, 27 Oct 2025 17:20:34 -0700
Subject: [PATCH 1/2] Continue refactoring into quantized_dispatchers

---
 sharktank/sharktank/ops/default_impls.py   |  40 ++-----
 sharktank/sharktank/ops/quantized_impls.py | 132 ++++++---------------
 2 files changed, 46 insertions(+), 126 deletions(-)

diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index b2ccc7c5a75..c2aa376847a 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -456,35 +456,17 @@ def gather_default(
     return torch.gather(unbox_tensor(input), dim, unbox_tensor(index))
 
 
-@extract_slice.override(AllOfType(Tensor, PrimitiveTensor))
-def extract_slice_default(tensor, key):
-    return unbox_tensor(tensor)[key]
-
-
-@extract_slice.override(QuantizedTensor)
-def extract_slice_QuantizedTensor(tensor: QuantizedTensor, key: slice):
-    unpacked = tensor.unpack()
-    if isinstance(unpacked, BlockScaledI4Layout):
-        mul = 2
-        new_d = unpacked._d[key]
-        new_qs = unpacked._qs[key]
-        if unpacked.m is not None:
-            new_m = unpacked.m[key]
-        dims = new_qs.shape
-        dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
-        layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
-        return PlanarQuantizedTensor(shape=dims, layout=layout)
-    elif isinstance(unpacked, TensorScaledLayout):
-        d = unpacked._d
-        qs = unpacked._qs[key]
-        if unpacked._m.dim() == 0:
-            m = unpacked._m
-        else:
-            m = unpacked._m[key]
-        shape = qs.shape
-        layout = TensorScaledLayout(shape=shape, d=d, qs=qs, m=m)
-        return PlanarQuantizedTensor(shape=shape, layout=layout)
-    return NotImplemented
+@extract_slice.override(BlockScaledI4Layout)
+def extract_slice_BlockScaledI4Layout(layout: BlockScaledI4Layout, key: slice):
+    mul = 2
+    new_d = layout._d[key]
+    new_qs = layout._qs[key]
+    # Layouts without an offset plane leave new_m as None.
+    new_m = layout.m[key] if layout.m is not None else None
+    dims = new_qs.shape
+    dims = dims[:-2] + (dims[-2] * dims[-1] * mul,)
+    new_layout = BlockScaledI4Layout(shape=dims, d=new_d, qs=new_qs, m=new_m)
+    return PlanarQuantizedTensor(shape=dims, layout=new_layout)
 
 
 @gemm.override(AllOfType(Tensor, InferenceTensor))
diff --git a/sharktank/sharktank/ops/quantized_impls.py b/sharktank/sharktank/ops/quantized_impls.py
index 8639f26e054..90bfead96f6 100644
--- a/sharktank/sharktank/ops/quantized_impls.py
+++ b/sharktank/sharktank/ops/quantized_impls.py
@@ -5,15 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-import functools
 import math
-import inspect
 import torch
 import warnings
 
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import Any, Callable
+from typing import Callable
 from torch import Tensor
 from ._registry import *
 from sharktank.types import (
@@ -53,78 +51,6 @@
 import iree.turbine.ops.iree
 
 
-def quantized_tensor_layout_of_type(
-    *layout_types: tuple[QuantizedLayout | None],
-    **kw_layout_types: dict[str, QuantizedLayout | None],
-) -> Callable[..., Any]:
-    """Decorator that check that the arguments have the expected QuantizedLayout.
-
-    If the arguments have the expected layout call the function. If not, return NotImplemented.
-
-    E.g.
-    ```
-    @my_fn.override(QuantizedTensor)
-    @quantized_tensor_layout_of_type(a=BlockScaledFp4Layout, b=SuperBlockOffsetScaled_4_6_Layout)
-    def my_fn_impl(a: QuantizedTensor, b: QuantizedTensor):
-        ...
-    ```
-
-    """
-
-    def decorator(f: Callable[..., Any]):
-        signature = inspect.signature(f)
-
-        @functools.wraps(f)
-        def wrapper(*args, **kwargs):
-            # torch.export doesn't play nicely with inspect
-            if torch._dynamo.is_compiling():
-                return f(*args, **kwargs)
-
-            bound_arguments = signature.bind(*args, **kwargs)
-            bound_layout_types = signature.bind_partial(
-                *layout_types, **kw_layout_types
-            )
-            for k, layout_type in bound_layout_types.arguments.items():
-                if layout_type is None:
-                    continue
-                if signature.parameters[k].kind == inspect.Parameter.VAR_POSITIONAL:
-                    if any(
-                        not isinstance(arg.to_planar().layout, l_type)
-                        for l_type, arg in zip(
-                            layout_type, bound_arguments.arguments[k]
-                        )
-                    ):
-                        return NotImplemented
-                if signature.parameters[k].kind == inspect.Parameter.VAR_KEYWORD:
-                    if any(
-                        not isinstance(
-                            bound_arguments.arguments[k][name].to_planar().layout,
-                            l_type,
-                        )
-                        for name, l_type in layout_type.items()
-                    ):
-                        return NotImplemented
-                if not isinstance(
-                    bound_arguments.arguments[k].to_planar().layout, layout_type
-                ):
-                    return NotImplemented
-
-            # All tensors have the expected layout, we can make the call.
-            return f(*args, **kwargs)
-
-        wrapper._layout_types = {}
-        if layout_types:
-            param_names = list(signature.parameters.keys())
-            wrapper._layout_types.update(
-                dict(zip(param_names[: len(layout_types)], layout_types))
-            )
-        if kw_layout_types:
-            wrapper._layout_types.update(kw_layout_types)
-        return wrapper
-
-    return decorator
-
-
 def verify_quantized_shape(actual: tuple[int, ...], expected: tuple[int, ...]):
     assert iterables_equal(
         actual, expected
@@ -417,15 +343,20 @@ def cat_BlockScaledFp4Layout(tensors: Sequence[PlanarQuantizedTensor], dim: int)
 
 
 @extract_slice.override(PlanarQuantizedTensor)
-@quantized_tensor_layout_of_type(tensor=BlockScaledFp4Layout)
-def extract_slice_BlockScaledFp4Layout(tensor: PlanarQuantizedTensor, key: Slice):
-    layout: BlockScaledFp4Layout = tensor.layout
-    slice_ = canonicalize_slice_descriptor(squeeze_slice(key), tensor.shape)
+def extract_slice_quantized_dispatcher(
+    tensor: PlanarQuantizedTensor, key: Slice
+) -> PlanarQuantizedTensor:
+    return extract_slice(tensor.unpack(), key)
+
+
+@extract_slice.override(BlockScaledFp4Layout)
+def extract_slice_BlockScaledFp4Layout(layout: BlockScaledFp4Layout, key: Slice):
+    slice_ = canonicalize_slice_descriptor(squeeze_slice(key), layout.shape)
     assert all(
         isinstance(s, slice) for s in slice_
     ), "Slicing with integers like tensor[1, 2, [3, 4]] is not supported. Only ranges are supported."
     block_shape = tuple(
-        tensor.shape[i] // layout.d.shape[i] for i in range(len(tensor.shape))
+        layout.shape[i] // layout.d.shape[i] for i in range(len(layout.shape))
     )
     assert (
         math.prod(block_shape) == layout.block_size
@@ -477,38 +408,45 @@ def extract_slice_BlockScaledFp4Layout(tensor: PlanarQuantizedTensor, key: Slice
     )
 
 
-@extract_slice.override(PlanarQuantizedTensor)
-@quantized_tensor_layout_of_type(tensor=TensorScaledLayout)
+@extract_slice.override(TensorScaledLayout)
 def extract_slice_TensorScaledLayout(
-    tensor: PlanarQuantizedTensor, key: Slice
+    layout: TensorScaledLayout, key: Slice
 ) -> PlanarQuantizedTensor:
-    planes = dict(tensor.layout.planes)
+    planes = dict(layout.planes)
     planes["qs"] = extract_slice(planes["qs"], key)
-    metadata = dict(tensor.layout.metadata)
-    metadata["shape"] = tensor.shape
+    metadata = dict(layout.metadata)
+    metadata["shape"] = layout.shape
     return PlanarQuantizedTensor(
-        shape=tensor.shape,
-        layout=type(tensor.layout).create(
-            shape=tensor.layout.shape, metadata=metadata, planes=planes
+        shape=layout.shape,
+        layout=type(layout).create(
+            shape=layout.shape, metadata=metadata, planes=planes
         ),
     )
 
 
-@split.override(QuantizedTensor)
-@quantized_tensor_layout_of_type(tensor=BlockScaledFp4Layout)
+@split.override(PlanarQuantizedTensor)
+def split_quantized_dispatcher(
+    tensor: PlanarQuantizedTensor,
+    split_size_or_sections: int | list[int],
+    dim: int = 0,
+) -> tuple[PlanarQuantizedTensor, ...]:
+    return split(tensor.unpack(), split_size_or_sections, dim)
+
+
+@split.override(BlockScaledFp4Layout)
 def split_BlockScaledFp4Layout(
-    tensor: QuantizedTensor,
+    layout: BlockScaledFp4Layout,
     split_size_or_sections: int | list[int],
     dim: int = 0,
-) -> tuple[QuantizedTensor, ...]:
-    dim = normalize_negative_dim(tensor, dim)
-    dim_size = tensor.shape[dim]
+) -> tuple[PlanarQuantizedTensor, ...]:
+    dim = normalize_negative_dim(layout.shape, dim)
+    dim_size = layout.shape[dim]
     if isinstance(split_size_or_sections, int):
         sections = [split_size_or_sections] * (dim_size // split_size_or_sections)
         reminder = dim_size % split_size_or_sections
         if reminder != 0:
             sections.append(reminder)
-        return split_BlockScaledFp4Layout(tensor, sections, dim)
+        return split_BlockScaledFp4Layout(layout, sections, dim)
 
     assert len(split_size_or_sections) > 0
     parts_range = [(0, split_size_or_sections[0])]
@@ -521,7 +459,7 @@ def split_BlockScaledFp4Layout(
         slice_ = tuple(
             slice(begin, end) if i == dim else slice(None) for i in range(dim + 1)
         )
-        res.append(tensor[slice_])
+        res.append(extract_slice(layout, slice_))
 
     return tuple(res)

From 63bcb96282b7d105cd99b50458465073336749a0 Mon Sep 17 00:00:00 2001
From: Kyle Herndon
Date: Tue, 28 Oct 2025 10:29:08 -0700
Subject: [PATCH 2/2] Try-catch necessary because of redispatching

---
 sharktank/sharktank/ops/default_impls.py   | 25 ++++++++++++++++++----
 sharktank/sharktank/ops/quantized_impls.py | 10 +++++++--
 2 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index c2aa376847a..c90e4c65e20 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -385,7 +385,10 @@ def expand_default(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
 def expand_quantized_dispatcher(
     tensor: PlanarQuantizedTensor, shape: List[int]
 ) -> PlanarQuantizedTensor:
-    return expand(tensor.unpack(), shape)
+    try:
+        return expand(tensor.unpack(), shape)
+    except NotImplementedError:
+        return NotImplemented
 
 
 @expand.override(TensorScaledLayout)
@@ -414,7 +417,10 @@ def flatten_default(
 def flatten_quantized_dispatcher(
     tensor: PlanarQuantizedTensor, start_dim: int, end_dim: int
 ) -> PlanarQuantizedTensor:
-    return flatten(tensor.unpack(), start_dim, end_dim)
+    try:
+        return flatten(tensor.unpack(), start_dim, end_dim)
+    except NotImplementedError:
+        return NotImplemented
 
 
 @flatten.override(TensorScaledLayout)
@@ -456,6 +462,11 @@ def gather_default(
     return torch.gather(unbox_tensor(input), dim, unbox_tensor(index))
 
 
+@extract_slice.override(AllOfType(Tensor, PrimitiveTensor))
+def extract_slice_default(tensor, key):
+    return unbox_tensor(tensor)[key]
+
+
 @extract_slice.override(BlockScaledI4Layout)
 def extract_slice_BlockScaledI4Layout(layout: BlockScaledI4Layout, key: slice):
     mul = 2
@@ -995,7 +1006,10 @@ def unsqueeze_default(tensor: Union[Tensor, PrimitiveTensor], dim: int) -> Tenso
 def unsqueeze_quantized_dispatcher(
     tensor: PlanarQuantizedTensor, dim: int
 ) -> PlanarQuantizedTensor:
-    return unsqueeze(tensor.unpack(), dim)
+    try:
+        return unsqueeze(tensor.unpack(), dim)
+    except NotImplementedError:
+        return NotImplemented
 
 
 @unsqueeze.override(TensorScaledLayout)
@@ -1179,7 +1193,10 @@ def view_default(
 
 @view.override(PlanarQuantizedTensor)
 def view_quantized_dispatcher(tensor: PlanarQuantizedTensor, shape, dtype):
-    return view(tensor.unpack(), shape, dtype)
+    try:
+        return view(tensor.unpack(), shape, dtype)
+    except NotImplementedError:
+        return NotImplemented
 
 
 @view.override(TensorScaledLayout)
diff --git a/sharktank/sharktank/ops/quantized_impls.py b/sharktank/sharktank/ops/quantized_impls.py
index 90bfead96f6..ef7ed999223 100644
--- a/sharktank/sharktank/ops/quantized_impls.py
+++ b/sharktank/sharktank/ops/quantized_impls.py
@@ -346,7 +346,10 @@ def cat_BlockScaledFp4Layout(tensors: Sequence[PlanarQuantizedTensor], dim: int)
 def extract_slice_quantized_dispatcher(
     tensor: PlanarQuantizedTensor, key: Slice
 ) -> PlanarQuantizedTensor:
-    return extract_slice(tensor.unpack(), key)
+    try:
+        return extract_slice(tensor.unpack(), key)
+    except NotImplementedError:
+        return NotImplemented
 
 
 @extract_slice.override(BlockScaledFp4Layout)
@@ -430,7 +433,10 @@ def split_quantized_dispatcher(
     split_size_or_sections: int | list[int],
     dim: int = 0,
 ) -> tuple[PlanarQuantizedTensor, ...]:
-    return split(tensor.unpack(), split_size_or_sections, dim)
+    try:
+        return split(tensor.unpack(), split_size_or_sections, dim)
+    except NotImplementedError:
+        return NotImplemented
 
 
 @split.override(BlockScaledFp4Layout)
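
Note on the pattern, for reviewers of both patches: patch 1 converges every quantized op on an unpack-and-redispatch scheme. The override registered for PlanarQuantizedTensor unpacks the tensor to its concrete layout and re-enters the same op, and a layout-specific override (BlockScaledI4Layout, BlockScaledFp4Layout, TensorScaledLayout) does the actual work. Patch 2 then guards each dispatcher because the inner call raises NotImplementedError when no layout override matches; converting that into a NotImplemented return lets the outer registry fall through to other candidate overrides instead of aborting. The sketch below is a minimal, self-contained reproduction of that control flow. Op, override, and the Toy* names are hypothetical stand-ins for illustration only, not sharktank's actual _registry implementation; only the dispatcher and override shapes mirror the patches.

from dataclasses import dataclass


class Op:
    """A tiny single-dispatch registry keyed on the first argument's type."""

    def __init__(self, name: str):
        self.name = name
        self._overrides = []  # (type, impl) pairs, tried in registration order

    def override(self, ty):
        def register(impl):
            self._overrides.append((ty, impl))
            return impl

        return register

    def __call__(self, *args, **kwargs):
        for ty, impl in self._overrides:
            if isinstance(args[0], ty):
                result = impl(*args, **kwargs)
                if result is not NotImplemented:
                    return result  # a NotImplemented result falls through
        raise NotImplementedError(f"{self.name}: no override for {type(args[0])!r}")


@dataclass
class ToyLayout:
    qs: list  # stand-in for a packed quantized plane


@dataclass
class ToyQuantizedTensor:
    layout: ToyLayout

    def unpack(self) -> ToyLayout:
        return self.layout


extract_slice = Op("extract_slice")


@extract_slice.override(ToyQuantizedTensor)
def extract_slice_quantized_dispatcher(tensor, key):
    # Redispatch on the unpacked layout. If no layout-specific override
    # exists, the registry raises NotImplementedError; convert that into
    # NotImplemented so the outer dispatch can try other overrides.
    try:
        return extract_slice(tensor.unpack(), key)
    except NotImplementedError:
        return NotImplemented


@extract_slice.override(ToyLayout)
def extract_slice_ToyLayout(layout, key):
    return ToyQuantizedTensor(ToyLayout(qs=layout.qs[key]))


t = ToyQuantizedTensor(ToyLayout(qs=list(range(8))))
print(extract_slice(t, slice(2, 6)).layout.qs)  # -> [2, 3, 4, 5]

Returning NotImplemented rather than letting the exception escape mirrors Python's own binary-operator protocol, and is the behavior the second commit subject ("Try-catch necessary because of redispatching") refers to.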