Skip to content

Commit fed71ef

Browse files
authored
[Relax] Add native size operator (#18667)
## Why ONNX models use the Size operator to get total element count of a tensor. Relax didn't have a native equivalent. ## How - Adds R.size(tensor) operator that returns the total number of elements in a tensor as a scalar int64
1 parent d8c973e commit fed71ef

File tree

7 files changed

+122
-7
lines changed

7 files changed

+122
-7
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,7 @@ class Size(OnnxOpConverter):
911911

912912
@classmethod
913913
def _impl_v1(cls, bb, inputs, attr, params):
914-
# TODO(tvm-team): add native support for size op
915-
return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
914+
return relax.op.size(inputs[0])
916915

917916

918917
class EyeLike(OnnxOpConverter):

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
register_gradient,
4141
shape_of,
4242
shape_to_tensor,
43+
size,
4344
tensor_to_shape,
4445
to_vdevice,
4546
)

python/tvm/relax/op/base.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,22 @@ def shape_of(expr: Expr) -> Expr:
634634
return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member
635635

636636

637+
def size(expr: Expr) -> Expr:
638+
"""Get the total number of elements in a tensor.
639+
640+
Parameters
641+
----------
642+
expr : Expr
643+
The input tensor.
644+
645+
Returns
646+
-------
647+
result : Expr
648+
A scalar tensor of dtype int64 containing the total number of elements.
649+
"""
650+
return _ffi_api.size(expr) # type: ignore # pylint: disable=no-member
651+
652+
637653
def tensor_to_shape(expr: Expr) -> Expr:
638654
"""Convert tensor to shape expr.
639655
Parameters
@@ -777,11 +793,13 @@ def call_pure_packed(
777793
sinfo_args = [sinfo_args]
778794

779795
sinfo_args = [
780-
sinfo()
781-
if callable(sinfo)
782-
else sinfo.asobject()
783-
if isinstance(sinfo, ObjectConvertible)
784-
else sinfo
796+
(
797+
sinfo()
798+
if callable(sinfo)
799+
else sinfo.asobject()
800+
if isinstance(sinfo, ObjectConvertible)
801+
else sinfo
802+
)
785803
for sinfo in sinfo_args
786804
]
787805

python/tvm/relax/transform/legalize_ops/inspect_op.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from ...block_builder import BlockBuilder
2525
from ...expr import Call, Expr
26+
from ... import op
2627
from .common import register_legalize
2728

2829

@@ -126,3 +127,8 @@ def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64:
126127

127128
gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
128129
return Call(gvar, call.args)
130+
131+
132+
@register_legalize("relax.size")
133+
def _size(_bb: BlockBuilder, call: Call) -> Expr:
134+
return op.prod(op.shape_to_tensor(op.shape_of(call.args[0])))

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
sign,
164164
sin,
165165
sinh,
166+
size,
166167
slice_scatter,
167168
sort,
168169
split,
@@ -938,6 +939,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
938939
"shape",
939940
"shape_of",
940941
"ShapeExpr",
942+
"size",
941943
"std",
942944
"str",
943945
"sum",

src/relax/op/op.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,32 @@ TVM_FFI_STATIC_INIT_BLOCK() {
11251125
refl::GlobalDef().def("relax.op.shape_of", MakeShapeOf);
11261126
}
11271127

1128+
// size
1129+
1130+
StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) {
1131+
auto arg_sinfo = GetStructInfo(call->args[0]);
1132+
auto* tensor_sinfo = GetStructInfo(call->args[0]).as<TensorStructInfoNode>();
1133+
CHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo
1134+
<< "; use MatchCast if necessary";
1135+
return TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>{}), DataType::Int(64));
1136+
}
1137+
1138+
TVM_REGISTER_OP("relax.size")
1139+
.set_num_inputs(1)
1140+
.add_argument("input", "Expr", "The input tensor")
1141+
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSize)
1142+
.set_attr<Bool>("FPurity", Bool(true));
1143+
1144+
Expr MakeSize(Expr expr) {
1145+
static const Op& op = Op::Get("relax.size");
1146+
return Call(op, {expr}, {}, {});
1147+
}
1148+
1149+
TVM_FFI_STATIC_INIT_BLOCK() {
1150+
namespace refl = tvm::ffi::reflection;
1151+
refl::GlobalDef().def("relax.op.size", MakeSize);
1152+
}
1153+
11281154
// tensor_to_shape
11291155

11301156
StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) {

tests/python/relax/test_op_size.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import numpy as np
19+
20+
import tvm
21+
import tvm.testing
22+
from tvm import relax
23+
from tvm.script import relax as R
24+
25+
26+
def test_op_size():
27+
@tvm.script.ir_module
28+
class Module:
29+
@R.function
30+
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((), "int64"):
31+
return R.size(x)
32+
33+
x_np = np.random.rand(2, 3).astype("float32")
34+
x = tvm.runtime.tensor(x_np)
35+
36+
target = tvm.target.Target("llvm")
37+
ex = relax.build(Module, target)
38+
vm = relax.VirtualMachine(ex, tvm.cpu())
39+
40+
res = vm["main"](x)
41+
assert res.numpy() == 6
42+
43+
44+
def test_op_size_dynamic():
45+
@tvm.script.ir_module
46+
class Module:
47+
@R.function
48+
def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((), "int64"):
49+
return R.size(x)
50+
51+
x_np = np.random.rand(4, 5).astype("float32")
52+
x = tvm.runtime.tensor(x_np)
53+
54+
target = tvm.target.Target("llvm")
55+
ex = relax.build(Module, target)
56+
vm = relax.VirtualMachine(ex, tvm.cpu())
57+
58+
res = vm["main"](x)
59+
assert res.numpy() == 20
60+
61+
62+
if __name__ == "__main__":
63+
tvm.testing.main()

0 commit comments

Comments
 (0)