Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,24 @@

@onnxscript.script(common_opset)
def Rank(input: tensor_typing.TTensor) -> INT64:
"""Take the rank of the input tensor."""
"""Deprecated.

NOTE: Do not remove, for backward compatibility with PyTorch < 2.10.

Take the rank of the input tensor.
"""

return op.Size(op.Shape(input))


@onnxscript.script(common_opset)
def IsScalar(input: tensor_typing.TTensor) -> BOOL:
"""Return whether the input has rank 0, or is a scalar."""
"""Deprecated.

NOTE: Do not remove, for backward compatibility with PyTorch < 2.10.

Return whether the input has rank 0, or is a scalar.
"""

return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))

Expand Down
65 changes: 22 additions & 43 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
_MATH_PI = math.pi
Rank = common_ops.Rank


@torch_op("aten::_local_scalar_dense", trace_only=True)
Expand Down Expand Up @@ -947,11 +946,11 @@ def reshape_to_1d(tensor):
return op.SequenceMap(self, body=reshape_to_1d)


@torch_op("aten::atleast_2d")
@torch_op("aten::atleast_2d", trace_only=True)
def aten_atleast_2d(self: TTensor) -> TTensor:
"""atleast_2d(Tensor self) -> Tensor"""

if Rank(self) <= 1:
if len(self.shape) <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
return op.Identity(self)

Expand All @@ -975,7 +974,7 @@ def reshape_to_2d(tensor):
def aten_atleast_3d(self: TTensor) -> TTensor:
"""atleast_3d(Tensor self) -> Tensor"""

rank = Rank(self)
rank = len(self.shape)
if rank <= 1:
self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
elif rank == 2:
Expand Down Expand Up @@ -1820,39 +1819,21 @@ def aten_conj_physical(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::constant_pad_nd")
def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor:
@torch_op("aten::constant_pad_nd", trace_only=True)
def aten_constant_pad_nd(self: TTensor, pad: Sequence[INT64], value: float = 0.0) -> TTensor:
"""constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor"""

# The desired order of paddings is
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
# n is the dimension of input.
# assume zero-dimensions in the beginning
# rank = len(self.shape) # rank must be scalar
# paddings = list(pad[:]) + [0] * (rank * 2 - len(pad))
rank = len(self.shape)
paddings = list(pad) + [0] * (rank * 2 - len(pad))
# reverse order and collate first beginnings and then ends
# paddings = paddings[-2::-2] + paddings[-1::-2]

neg_1 = op.Constant(value_ints=[-1])

zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad))
zero_count = op.Reshape(zero_count, neg_1)
zero = op.Constant(value_ints=[0])
zeros = op.Expand(zero, zero_count)
torch_paddings = op.Concat(pad, zeros, axis=0)
size_d = op.Size(torch_paddings)
steps = op.Constant(value_ints=[-2])

starts = steps
ends = op.Sub(starts, size_d)
odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)

starts = neg_1
ends = op.Sub(starts, size_d)
even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
paddings = paddings[-2::-2] + paddings[-1::-2]
constant_value = op.Constant(value=ir.tensor(value, dtype=self.dtype))

onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
return op.Pad(self, onnx_padding, value)
return op.Pad(self, paddings, constant_value)


@torch_op("aten::contiguous", trace_only=True)
Expand Down Expand Up @@ -3996,7 +3977,7 @@ def reshape_to_atleast_2d(tensor):
result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, new_axis=0)

# hstack expects a non-empty sequence of tensors. So we don't need to check for length
rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2)
rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2)
if rank_1d_or_less:
result = op.Reshape(result, op.Constant(value_ints=[-1]))
return result
Expand Down Expand Up @@ -6076,7 +6057,7 @@ def aten_native_group_norm(
norm = op.Reshape(norm, op.Shape(input), allowzero=True)
# Using the input weight and bias to do affine
# But need to unsqueeze to the target shape for broading cast easy
input_rank = Rank(input)
input_rank = len(input.shape)
axes_unsqueeze = op.Range(1, input_rank - 1, 1)
weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
Expand Down Expand Up @@ -8229,7 +8210,7 @@ def aten_symeig(
def aten_t(self: TTensor) -> TTensor:
"""t(Tensor(a) self) -> Tensor(a)"""

rank = Rank(self)
rank = len(self.shape)
if rank == 2:
result = op.Transpose(self, perm=[1, 0])
else:
Expand Down Expand Up @@ -8312,26 +8293,24 @@ def aten_threshold_backward(
raise NotImplementedError()


@torch_op("aten::tile")
def aten_tile(self: TTensor, dims: INT64) -> TTensor:
@torch_op("aten::tile", trace_only=True)
def aten_tile(self: TTensor, dims: Sequence[int]) -> TTensor:
"""tile(Tensor self, int[] dims) -> Tensor"""

self_rank = Rank(self)
dims_rank = op.Size(dims)
diff = op.Sub(self_rank, dims_rank)
self_rank = len(self.shape)
dims_rank = len(dims)
diff = self_rank - dims_rank

if diff > 0:
# dims is shorter than self.shape
# pad dims with 1
diff_1d = op.Reshape(diff, op.Constant(value_ints=[1]))
exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
dims = op.Concat(exapnd_ones, dims, axis=0)
exapnd_ones = [1] * diff
dims = [*exapnd_ones, *dims]

if diff < 0:
elif diff < 0:
# dims is longer than self.shape
# pad self.shape with 1
diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1]))
exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
exapnd_ones = op.Constant(value_ints=[1] * (-diff))
self_shape = op.Shape(self)
self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
self = op.Reshape(self, self_final_shape, allowzero=True)
Expand Down
12 changes: 5 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Optional, Sequence, Tuple, TypeVar, Union

from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import (
IntType,
Expand All @@ -32,7 +31,6 @@
from onnxscript.onnx_types import TensorType

_MATH_PI = math.pi
Rank = common_ops.Rank

_INT64_MAX = 9223372036854775807
_INT64_MIN = -9223372036854775808
Expand Down Expand Up @@ -576,7 +574,7 @@ def aten_group_norm(
norm = op.Reshape(norm, op.Shape(input))
# Using the input weight and bias to do affine
# But need to unsqueeze to the target shape for broading cast easy
input_rank = Rank(input)
input_rank = len(input.shape)
one = op.Constant(value_int=1)
axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one)
weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
Expand Down Expand Up @@ -999,7 +997,7 @@ def _aten_max_pool_onnx(
ceil_mode: bool,
unbatched_rank: int,
) -> TFloatOrUInt8:
self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank
if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
self = op.Unsqueeze(self, [0])

Expand Down Expand Up @@ -1133,7 +1131,7 @@ def _aten_max_pool_with_indices_onnx(
n_dims_zero: Sequence[int],
n_dims_axes: Sequence[int],
) -> Tuple[TFloatOrUInt8, INT64]:
self_rank_is_unbatched_rank = Rank(self) == unbatched_rank
self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank
if self_rank_is_unbatched_rank:
self = op.Unsqueeze(self, axes=[0])

Expand Down Expand Up @@ -1362,11 +1360,11 @@ def aten_nll_loss(
) -> TFloat:
"""nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor"""

self_rank_is_1 = Rank(self) == 1
self_rank_is_1 = len(self.shape) == 1
if self_rank_is_1: # self rank should be at least 2
self = op.Unsqueeze(self, [0])

rank_target = Rank(target)
rank_target = len(target.shape)
if rank_target == 0: # target rank should be at least 1
target = op.Unsqueeze(target, [0])

Expand Down
16 changes: 0 additions & 16 deletions tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import numpy as np
import onnx
import onnx_ir.passes.common as common_passes
import onnxruntime as ort
import onnxruntime.capi.onnxruntime_pybind11_state
import pytest
Expand All @@ -37,7 +36,6 @@
import onnxscript
import onnxscript.evaluator
from onnxscript import ir
from onnxscript.function_libs.torch_lib.ops import common as common_ops
from tests.function_libs.torch_lib import error_reproduction

T = TypeVar("T")
Expand Down Expand Up @@ -412,19 +410,6 @@ def _format_model_and_input_information(onnx_model, inputs):
}


def add_torchlib_common_imports(model: ir.Model) -> None:
"""Hack to add torchlib common imports to the model."""

model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1
rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto())
is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto())
model.functions[rank_func.identifier()] = rank_func
model.functions[is_scalar_func.identifier()] = is_scalar_func
removal_pass = common_passes.RemoveUnusedFunctionsPass()
assert removal_pass.in_place
removal_pass(model)


def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool:
"""Checks if the dtype is compatible with the schema.

Expand Down Expand Up @@ -593,7 +578,6 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
proto = onnxscript_function.to_function_proto()
ir_function = ir.serde.deserialize_function(proto)
onnx_model.functions[identifier] = ir_function
add_torchlib_common_imports(onnx_model)
# Make sure the model is valid
model_proto = ir.to_proto(onnx_model)
try:
Expand Down
Loading