@@ -1,6 +1,6 @@
 from __future__ import annotations
 from collections.abc import Callable, Sequence
-from functools import partial, reduce
+from functools import partial, reduce, update_wrapper
 from numbers import Number
 from types import EllipsisType, NoneType
 from typing import Any, Union
@@ -10,6 +10,7 @@
 
 from thunder.clang.langctx import register_method
 from thunder.clang.utils import create_maybe_convert_to_dtype_with_prim, _elementwise_unary_wrapper
+import thunder.clang.utils as clang_utils
 from thunder.core import utils
 from thunder.core.baseutils import run_once
 from thunder.core.langctxs import langctx, Languages
@@ -368,33 +369,13 @@ def diagonal(a: TensorLike, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TensorLike:
 
 # Expands a to the specified shape, possibly adding new dimensions and expanding
 # dimensions of length 1 to any length
-@clangop()
-def expand(a: TensorLike, *shape: int) -> TensorLike:
-    shape = utils.extract_shape_from_varargs(shape)
-
-    # TODO: improve this error message with error context
-    utils.check(
-        len(shape) >= len(a.shape),
-        lambda: "expand: the requested shape has too few dimensions!",
-    )
-
-    offset = len(shape) - len(a.shape)
-    shape_ = list(shape)
-    for idx, x in enumerate(a.shape):
-        offset_idx = idx + offset
-        requested_length = shape[offset_idx]
-        utils.check(
-            requested_length == x or x == 1 or requested_length == -1,
-            lambda: f"expand: attempting to expand a dimension of length {x}!",
-        )
-
-        shape_[offset_idx] = requested_length if requested_length != -1 else x
-
-    # At this point shape must be valid
-    # utils.check_valid_shape(shape_)
+expand = clangop()(partial(clang_utils.expand_impl, broadcast_prim=prims.broadcast_in_dim))
+# To preserve the docstring
+update_wrapper(expand, clang_utils.expand_impl)
 
-    # NOTE: Converting shape_ to tuple makes it possible to apply CSE
-    return prims.broadcast_in_dim(a, tuple(shape_), tuple(range(offset, len(a.shape) + offset)))
+maybe_broadcast = clangop()(partial(clang_utils.maybe_broadcast_impl, expand_fn=expand))
+# To preserve the docstring
+update_wrapper(maybe_broadcast, clang_utils.maybe_broadcast_impl)
 
 
 # TODO Resolve the start & end vs. start & stop inconsistencies with our operators (this one is start & end)
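
Note on the `update_wrapper` calls above: `functools.partial` objects support attribute assignment, so `update_wrapper` can copy `__doc__`, `__name__`, `__module__`, and friends from the wrapped implementation onto the partial. A minimal sketch of the pattern outside Thunder (the `expand_impl` signature below is a hypothetical stand-in, not the real one):

from functools import partial, update_wrapper

def expand_impl(a, *shape, broadcast_prim):
    """Expands ``a`` to the requested shape."""
    return broadcast_prim(a, shape)

# Binding the primitive yields a ``partial`` whose metadata describes
# ``partial`` itself, not the implementation:
expand = partial(expand_impl, broadcast_prim=lambda a, shape: (a, shape))
assert "Expands" not in (expand.__doc__ or "")

# update_wrapper copies __doc__, __name__, __module__, etc. onto the
# partial, so help(expand) shows the implementation's docstring again.
update_wrapper(expand, expand_impl)
assert expand.__doc__ == expand_impl.__doc__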
@@ -1085,31 +1066,7 @@ def stack(tensors: list[TensorProxy], dim: int):
     return cat(tensors_, dim)
 
 
-@clangop()
-def compute_broadcast_shape(*_shapes):
-    """Computes the common shape with the fewest dimensions that all input shapes can be broadcast to."""
-    shapes = tuple(x for x in filter(lambda x: x is not None, _shapes))
-
-    # Short-circuits if there are no inputs shapes
-    # This might happen in calls like add(2, 3)
-    if len(shapes) == 0:
-        return None
-
-    common_shape = [
-        1,
-    ] * reduce(max, (len(shape) for shape in shapes))
-
-    for shape in shapes:
-        for idx in range(-1, -1 - len(shape), -1):
-            if common_shape[idx] == 1:
-                common_shape[idx] = shape[idx]
-
-            utils.check(
-                (shape[idx] == 1) or (common_shape[idx] == shape[idx]),
-                lambda: f"Attempting to broadcast a dimension of length {shape[idx]}!",
-            )
-
-    return tuple(common_shape)
+compute_broadcast_shape = clangop()(clang_utils.compute_broadcast_shape)
 
 
 @run_once
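
For reference, the rule implemented by the relocated `compute_broadcast_shape` is ordinary right-aligned broadcasting: shapes are matched from the trailing dimension, and a length-1 dimension stretches to whatever the other shapes require. A self-contained sketch of the same algorithm, with illustrative names rather than Thunder's:

from functools import reduce

def broadcast_shape(*shapes):
    # No tensor operands (e.g. add(2, 3)) -> no common shape.
    if not shapes:
        return None
    # Start from all-ones at the maximum rank, then right-align each shape.
    common = [1] * reduce(max, (len(s) for s in shapes))
    for idx_shape in shapes:
        for idx in range(-1, -1 - len(idx_shape), -1):
            if common[idx] == 1:
                common[idx] = idx_shape[idx]
            elif idx_shape[idx] != 1 and common[idx] != idx_shape[idx]:
                raise ValueError(f"Attempting to broadcast a dimension of length {idx_shape[idx]}!")
    return tuple(common)

assert broadcast_shape((3, 1, 5), (4, 5)) == (3, 4, 5)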
@@ -1155,28 +1112,6 @@ def matrix_transpose(a: TensorProxy) -> TensorProxy:
     return transpose(a, permutation)
 
 
-# TODO: add scalar support
-# TODO: review hasattr pattern
-# NOTE: the tensor is not broadcasted if it is a CPU scalar tensor and treat_cpu_scalar_tensors_as_numbers=True
-@clangop()
-def maybe_broadcast(*args, treat_cpu_scalar_tensors_as_numbers=True):
-    """Returns tensors with the same shape, possibly broadcasting inputs to the result shape."""
-
-    # Computes common shape
-    common_shape = compute_broadcast_shape(*map(lambda t: t.shape if hasattr(t, "shape") else None, args))
-
-    def _maybe_broadcast(x, shape):
-        if treat_cpu_scalar_tensors_as_numbers and utils.is_cpu_scalar_tensor(x):
-            return x
-        if hasattr(x, "shape"):
-            if not utils.same_shape(x.shape, common_shape):
-                return expand(x, common_shape)
-
-        return x
-
-    return tuple(_maybe_broadcast(x, common_shape) for x in args)
-
-
 #
 # Elementwise unary operations
 #
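
One design note on the `expand_fn=expand` binding earlier in the diff: passing the expand callable into `maybe_broadcast_impl` as a parameter, rather than having `thunder.clang.utils` import it, presumably keeps the utils module free of an import back into the package that defines the ops. A self-contained sketch of that injection pattern; `FakeTensor`, `fake_expand`, and the simplified `maybe_broadcast_impl` body are hypothetical stand-ins (error checking and the CPU-scalar-tensor special case are omitted):

from functools import partial

class FakeTensor:
    # Hypothetical stand-in for a TensorProxy; it only carries a shape.
    def __init__(self, shape):
        self.shape = tuple(shape)

def fake_expand(t, shape):
    # Stand-in for clang.expand: just produce a tensor of the target shape.
    return FakeTensor(shape)

def maybe_broadcast_impl(*args, expand_fn):
    # Same right-aligned rule as the removed maybe_broadcast; anything
    # without a .shape (plain numbers) passes through unchanged.
    shapes = [t.shape for t in args if hasattr(t, "shape")]
    if not shapes:
        return args
    common = [1] * max(len(s) for s in shapes)
    for s in shapes:
        for idx in range(-1, -1 - len(s), -1):
            if common[idx] == 1:
                common[idx] = s[idx]
    common = tuple(common)
    return tuple(
        expand_fn(t, common) if hasattr(t, "shape") and t.shape != common else t
        for t in args
    )

# The concrete expand is bound at the definition site, not imported by the util.
maybe_broadcast = partial(maybe_broadcast_impl, expand_fn=fake_expand)

a, b, n = maybe_broadcast(FakeTensor((3, 1)), FakeTensor((4,)), 2.0)
assert a.shape == (3, 4) and b.shape == (3, 4) and n == 2.0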