Skip to content

Commit 434fcfb

Browse files
committed
type constant
1 parent c9472f3 commit 434fcfb

File tree

1 file changed

+16
-10
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+16
-10
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4468,15 +4468,21 @@ def _aten_index_put_dynamic(
44684468
values: TReal,
44694469
accumulate: bool = False,
44704470
) -> TReal:
4471+
def _1dint(i: int):
4472+
return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))
4473+
4474+
def _0dint(i: int):
4475+
return op.Constant(value_int=ir.AttrInt64("value_int", i))
4476+
44714477
def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
44724478
if ind is not None:
44734479
return op.Cast(ind, to=INT64.dtype), False
44744480
return (
44754481
op.Cast(
44764482
op.Range( # Range does not return a typed result
4477-
0,
4483+
_0dint(0),
44784484
op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
4479-
1,
4485+
_0dint(1),
44804486
),
44814487
to=INT64.dtype,
44824488
),
@@ -4500,21 +4506,21 @@ def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
45004506
if expanded:
45014507
exped.append((i, ind))
45024508
expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
4503-
reshape_value_shape2.append([1])
4509+
reshape_value_shape2.append(_1dint(1))
45044510
else:
4505-
expand_value_shape.append([1])
4511+
expand_value_shape.append(_1dint(1))
45064512
reshape_value_shape2.append(op.Shape(ind))
45074513
fixed.append((i, ind))
45084514

4509-
reshape_value_shape1 = [1] * len(indices)
4515+
reshape_value_shape1 = [_1dint(1)] * len(indices)
45104516
if len(fixed) <= 1:
45114517
reshape_value_shape1 = None
45124518
elif fixed:
4513-
reshape_value_shape1[fixed[-1][0]] = -1
4519+
reshape_value_shape1[fixed[-1][0]] = _1dint(-1)
45144520

45154521
def _mkstride(x, i):
45164522
if i >= len(x.shape) - 1:
4517-
return [1]
4523+
return _1dint(1)
45184524
if i == len(x.shape) - 2:
45194525
return op.Shape(x, start=i + 1)
45204526
return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)
@@ -4547,9 +4553,9 @@ def _mkstride(x, i):
45474553
# Bug here: Error calling operator 'Concat' with args
45484554
# (SymbolicTensor(name='anonymous:124529632436112', producer=anonymous_node:124529631522416, index=0), [1], [1])
45494555
expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
4550-
flat_ind = op.Reshape(unflat, [-1])
4551-
expanded_values = op.Reshape(expanded_values, [-1])
4552-
flat_x = op.Reshape(x, [-1])
4556+
flat_ind = op.Reshape(unflat, _1dint(-1))
4557+
expanded_values = op.Reshape(expanded_values, _1dint(-1))
4558+
flat_x = op.Reshape(x, _1dint(-1))
45534559
scat_kwargs = {"reduction": "add"} if accumulate else {}
45544560
flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
45554561
return op.Reshape(flat_up_x, op.Shape(x))

0 commit comments

Comments
 (0)