From 9b206e9dd37ee2892009f9416e428cce577fe80c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 14:31:59 -0700 Subject: [PATCH 01/23] WIP Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 127 +----------------- onnxscript/function_libs/torch_lib/ops/nn.py | 80 ++--------- 2 files changed, 17 insertions(+), 190 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1a688a4277..dc87e172af 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7517,99 +7517,11 @@ def aten_rnn_tanh_cell( raise NotImplementedError() -@torch_op("aten::roll", trace_only=True) +# roll is decomposed by PyTorch def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" - self_rank = len(self.shape) - if self_rank == 0: - return op.Identity(self) - elif self.shape[0] == 0: # empty tensor - return op.Identity(self) - else: - # NOTE: In pytorch, default value of dims is an empty list. - if len(dims) == 0: # Empty sequence - # assert isinstance(shifts, int) - return _aten_roll_shift_no_dim_onnx(self, shifts) - else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list - result = self - for i, shift in enumerate(shifts): - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result - - -@torch_op("aten::roll", trace_only=True, complex=True) -def aten_roll_complex( - self: TTensor, shifts: Sequence[int], dims: Sequence[int] = () -) -> TTensor: - """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" - - self_rank = len(self.shape) - if self_rank == 1: - return op.Identity(self) - - if self.shape[0] == 0: # empty tensor - return op.Identity(self) - - self_real = op.Slice(self, [0], [1], axes=[-1]) - self_imag = op.Slice(self, [1], [2], axes=[-1]) - if not dims: - # assert isinstance(shifts, int) - shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) - shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) - - result = op.Concat(shift_real, shift_imag, axis=-1) - - else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list - for i, dim in enumerate(dims): - shift = op.Gather(shifts, i, axis=0) - self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) - self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) - - result = op.Concat(self_real, self_imag, axis=-1) - return result - - -@torch_op("aten::roll", private=True) -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: - neg_1 = op.Constant(value_ints=[-1]) - # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] - self_flatten = op.Reshape(self, neg_1) - # Compute slice length - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end - slice_length = -shift_tensor - else: - # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end - # The effect equals to move [D] to the beginning - slice_length = op.Size(self_flatten) - shift_tensor - # Get second part of the tensor, e.g. [A,B,C] - suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) - # Get first part of the tensor, e.g. [D] - prefix = op.Slice(self_flatten, slice_length, op.Reshape(op.Size(self_flatten), neg_1)) - # Concat first+second together, e.g. [D,A,B,C] - result = op.Concat(prefix, suffix, axis=0) - return op.Reshape(result, op.Shape(self)) - - -@torch_op("aten::roll", private=True) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: - neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - slice_length = -shift_tensor - else: - slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor - # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix - suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) - prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) - result = op.Concat(prefix, suffix, axis=dim) - return result + raise NotImplementedError() def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType: @@ -7683,7 +7595,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) def aten_scalar_tensor( - s: float, + s, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -7692,8 +7604,7 @@ def aten_scalar_tensor( """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # Set trace_only=True because different if branches return different dtypes - # which is not supported in an ONNX function + return common_ops.cast_to(s, dtype=dtype) @@ -7722,18 +7633,7 @@ def aten_scalar_tensor_complex( return result -@torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor_sym_number( - s: TensorType, - dtype: int = FLOAT.dtype, - layout: str = "", - device: str = "", - pin_memory: bool = False, -) -> RealType: - """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - if dtype == -1: - dtype = FLOAT.dtype - return common_ops.cast_to(s, dtype=dtype) + @torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True) @@ -8105,7 +8005,7 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) - if dtype != -1: + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) if self_is_scalar: # Convert to scalar when input is scalar @@ -8114,21 +8014,6 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: return result -@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat: - """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" - - self_is_scalar = len(self.shape) == 0 - if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.Softmax(self, axis=dim) - if self_is_scalar: - # Convert to scalar when input is scalar - result = op.Squeeze(result) - - return result - - @torch_op("aten::sort", trace_only=True) def aten_sort( self: TReal, dim: int = -1, descending: bool = False, stable: bool = False diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 1a31c9eac8..80be5655f4 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -294,20 +294,16 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu", trace_only=True) -def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - return op.Celu(self, alpha=alpha) # op.Celu only support float32 + if self.dtype != FLOAT.dtype: + self_upcasted = op.Cast(self, to=FLOAT.dtype) + # op.Celu only support float32 + return op.Cast(op.Celu(self_upcasted, alpha=alpha), to=self.dtype) -@torch_op("aten::celu", trace_only=True) -def aten_celu_type_promoted( - self: TFloatUnlessFloat32, alpha: float = 1.0 -) -> TFloatUnlessFloat32: - """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - - self_upcasted = op.Cast(self, to=FLOAT.dtype) - return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self) + return op.Celu(self, alpha=alpha) @torch_op("aten::col2im", trace_only=True) @@ -1854,6 +1850,11 @@ def aten_scaled_dot_product_attention( query, key, value, scale, dropout_p ) + if attn_mask.dtype == ir.DataType.BOOL: + return _aten_scaled_dot_product_attention_bool_mask_onnx( + query, key, value, attn_mask, scale, dropout_p + ) + return _aten_scaled_dot_product_attention_float_mask_onnx( query, key, value, attn_mask, scale, dropout_p ) @@ -1921,7 +1922,6 @@ def aten__scaled_dot_product_flash_attention( ) -@torch_op("aten::_scaled_dot_product_efficient_attention", private=True) def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, @@ -2016,64 +2016,6 @@ def aten__scaled_dot_product_efficient_attention( ) -@torch_op("aten::scaled_dot_product_attention", trace_only=True) -def aten_scaled_dot_product_attention_bool_mask( - query: TFloat, - key: TFloat, - value: TFloat, - attn_mask: Optional[BOOL] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, -) -> TFloat: - """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor - - Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - - Equivalent to the PyTorch code:: - scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale - attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask - attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask - attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p) - return attn_weight @ V - - where Q, K, V are the query, key, and value tensors, respectively. - L is the target sequence length, S is the source sequence length, and E is the embedding size. - """ - # Use trace_only to handle optional inputs - assert (not is_causal) or (is_causal and attn_mask is None), ( - "is_causal and attn_mask cannot be set at the same time" - ) - assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( - "only 4D query, key, and value are supported" - ) - - if scale is None: - scale = _attention_scale(query) - scale = op.CastLike(scale, query) - - if is_causal: - attn_mask = _causal_attention_mask(query, key) - # The causal mask is always float - return _aten_scaled_dot_product_attention_float_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - if enable_gqa: - key, value = _attention_repeat_kv_for_group_query(query, key, value) - - if attn_mask is None: - return _aten_scaled_dot_product_attention_no_mask_onnx( - query, key, value, scale, dropout_p - ) - - return _aten_scaled_dot_product_attention_bool_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - def _aten_scaled_dot_product_attention_no_mask_onnx( query: TFloat, key: TFloat, From 7457bbc205c7dcd3b212045f809c5e7f1423f626 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 14:34:51 -0700 Subject: [PATCH 02/23] normal Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 38 ++++++------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index dc87e172af..63efc8ecfe 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6516,16 +6516,7 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp raise NotImplementedError() -@torch_op( - ( - "aten::normal.Tensor_float", - "aten::normal.Tensor_Tensor", - "aten::normal.float_Tensor", - "aten::normal.float_float", - "aten::normal_functional", - ), - trace_only=True, -) +@torch_op("aten::normal_functional", trace_only=True) def aten_normal( self: TTensor, mean: float = 0.0, @@ -6554,7 +6545,7 @@ def aten_normal_float_float( return op.Cast(result, to=dtype) -@torch_op("aten::normal.float_Tensor") +@torch_op("aten::normal.float_Tensor", trace_only=True) def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: """normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6564,7 +6555,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: return op.Add(op.Mul(std, sampled), mean_casted) -@torch_op("aten::normal.Tensor_float") +@torch_op("aten::normal.Tensor_float", trace_only=True) def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: """normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor""" @@ -6573,7 +6564,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean) -@torch_op("aten::normal.Tensor_Tensor") +@torch_op("aten::normal.Tensor_Tensor", trace_only=True) def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat: """normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -7281,10 +7272,15 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) -def aten_remainder(self: TFloat, other: TFloat) -> TFloat: +@torch_op( + ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True +) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + if self.dtype.is_integer(): + return op.Mod(self, other) + # TODO(justinchuby): Improve fp16 precision by following the logic in # https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277 @@ -7294,15 +7290,6 @@ def aten_remainder(self: TFloat, other: TFloat) -> TFloat: return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op( - ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True -) -def aten_remainder_int(self: TInt, other: TInt) -> TInt: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" - - return op.Mod(self, other) - - def aten_rename(self: TensorType, names: Optional[str]) -> TensorType: """rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)""" @@ -7633,9 +7620,6 @@ def aten_scalar_tensor_complex( return result - - - @torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True) def aten_scatter( self: TReal, From 4336e0f685aafb36e1904329c7db5f46c64800ca Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 14:44:29 -0700 Subject: [PATCH 03/23] wip Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 106 +++++------------- 1 file changed, 30 insertions(+), 76 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 63efc8ecfe..e3af69f63c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4022,22 +4022,16 @@ def aten_gru_cell( def aten_gt(self: TReal, other: TReal) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Greater(self, other) + if self.dtype == ir.DataType.BOOL: + # self, other, self > other + # F, F, F + # F, T, F + # T, F, T + # T, T, F + return op.And(self, op.Not(other)) -@torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), - trace_only=True, -) -def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: - """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self > other - # F, F, F - # F, T, F - # T, F, T - # T, T, F - - return op.And(self, op.Not(other)) + return op.Greater(self, other) @torch_op("aten::hamming_window", trace_only=True) @@ -5161,23 +5155,15 @@ def aten_lstm_mps_backward( def aten_lt(self: TReal, other: TReal) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Less(self, other) - - -@torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), - trace_only=True, -) -def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: - """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - - # self, other, self < other - # F, F, F - # F, T, T - # T, F, F - # T, T, F + if self.dtype == ir.DataType.BOOL: + # self, other, self < other + # F, F, F + # F, T, T + # T, F, F + # T, T, F + return op.And(other, op.Not(self)) - return op.And(other, op.Not(self)) + return op.Less(self, other) def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: @@ -5352,17 +5338,13 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I @torch_op("aten::maximum") -def aten_maximum(self: TReal, other: TReal) -> TReal: +def aten_maximum(self: TTensor, other: TTensor) -> TTensor: """maximum(Tensor self, Tensor other) -> Tensor""" - return op.Max(self, other) - - -@torch_op("aten::maximum") -def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL: - """maximum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) - return op.Or(self, other) + return op.Max(self, other) @torch_op("aten::mean") @@ -5418,18 +5400,14 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T return result, indices -@torch_op("aten::minimum") -def aten_minimum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::minimum", trace_only=True) +def aten_minimum(self: TTensor, other: TTensor) -> TTensor: """minimum(Tensor self, Tensor other) -> Tensor""" - return op.Min(self, other) - - -@torch_op("aten::minimum") -def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL: - """minimum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - return op.And(self, other) + return op.Min(self, other) def aten_miopen_batch_norm( @@ -5772,23 +5750,13 @@ def aten_msort(self: TensorType) -> TensorType: ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), trace_only=True, ) -def aten_mul(self: TReal, other: TReal) -> TReal: +def aten_mul(self: TTensor, other: TTensor) -> TTensor: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Mul(self, other) - - -@torch_op( - ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), - trace_only=True, -) -def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: - """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" - - # TODO(justinchuby): Handle cases where type reconcilation is not enough, - # since different ONNX operators are used based on different data types. + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) - return op.And(self, other) + return op.Mul(self, other) @torch_op( @@ -6030,7 +5998,6 @@ def aten_native_batch_norm( return norm, input_mean, input_rstd -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_training_onnx( input: TFloat, weight: TFloat, @@ -6082,7 +6049,6 @@ def _aten_native_batch_norm_training_onnx( return norm, mean, rstd, running_mean, new_running_var -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_inference_onnx( input: TFloat, weight: TFloat, @@ -6252,22 +6218,10 @@ def aten_native_group_norm( if bias is None: # Set to 0.0 as default, the shape is Channel size bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) - # Accoding to Torch, return rstd instead of var - norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps) - return norm, mean, rstd - - -@torch_op("aten::native_group_norm", private=True) -def _aten_native_group_norm_onnx( - input: TFloat, - weight: TFloat, - bias: TFloat, - group: INT64, - eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate + # This implementation should be simplified after opset 21 neg_1 = op.Constant(value_ints=[-1]) # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter group_tensor = op.Reshape(group, neg_1) From 93d6647c13838fe576b76e9945aa67ba7c131b48 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 14:47:09 -0700 Subject: [PATCH 04/23] update Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e3af69f63c..f0f0993f4d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5073,29 +5073,23 @@ def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloat) -> TFloat: - return op.Log(op.Div(self, op.Sub(1.0, self))) +@torch_op("aten::logit", trace_only=True) +def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: + """logit(Tensor self, float? eps=None) -> Tensor""" + one = ir.tensor(1, dtype=self.dtype) + if eps is None: + return op.Log(op.Div(self, op.Sub(one, self))) + + one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype) + eps = ir.tensor(eps, dtype=self.dtype) -@torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: - eps = op.CastLike(eps, self) - one = op.CastLike(1.0, self) - temporary_self = op.Where(self <= one - eps, self, one - eps) + temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps) z = op.Where(temporary_self < eps, eps, temporary_self) return op.Log(op.Div(z, op.Sub(one, z))) -@torch_op("aten::logit", trace_only=True) -def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: - """logit(Tensor self, float? eps=None) -> Tensor""" - if eps is None: - return _aten_logit_onnx(self) - return _aten_logit_clamp_onnx(self, eps) - - def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType: """logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" From 2b0861c83000eff76d2401314393beecfb348869 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 14:48:46 -0700 Subject: [PATCH 05/23] update Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 54 +++++++------------ 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f0f0993f4d..d349453fa5 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3871,26 +3871,18 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), trace_only=True, ) -def aten_ge(self: TReal, other: TReal) -> BOOL: +def aten_ge(self: TTensor, other: TTensor) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.GreaterOrEqual(self, other) - - -@torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), - trace_only=True, -) -def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL: - """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - - # self, other, self >= other - # F, F, T - # F, T, F - # T, F, T - # T, T, T + if self.dtype == ir.DataType.BOOL: + # self, other, self >= other + # F, F, T + # F, T, F + # T, F, T + # T, T, T + return op.Or(self, op.Not(other)) - return op.Or(self, op.Not(other)) + return op.GreaterOrEqual(self, other) def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: @@ -4019,7 +4011,7 @@ def aten_gru_cell( ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), trace_only=True, ) -def aten_gt(self: TReal, other: TReal) -> BOOL: +def aten_gt(self: TTensor, other: TTensor) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" if self.dtype == ir.DataType.BOOL: @@ -4852,26 +4844,20 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), trace_only=True, ) -def aten_le(self: TReal, other: TReal) -> BOOL: +def aten_le(self: TTensor, other: TTensor) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.LessOrEqual(self, other) - + if self.dtype == ir.DataType.BOOL: + # self, other, self <= other + # F, F, T + # F, T, T + # T, F, F + # T, T, T -@torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), - trace_only=True, -) -def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: - """le.Tensor(Tensor self, Tensor other) -> Tensor""" + return op.Or(other, op.Not(self)) - # self, other, self <= other - # F, F, T - # F, T, T - # T, F, F - # T, T, T + return op.LessOrEqual(self, other) - return op.Or(other, op.Not(self)) @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) @@ -5146,7 +5132,7 @@ def aten_lstm_mps_backward( ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), trace_only=True, ) -def aten_lt(self: TReal, other: TReal) -> BOOL: +def aten_lt(self: TTensor, other: TTensor) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" if self.dtype == ir.DataType.BOOL: From d8124b9588ab516f59586e290d9df98c938280c6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:01:01 -0700 Subject: [PATCH 06/23] bit shift Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 347 ++++-------------- 1 file changed, 74 insertions(+), 273 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d349453fa5..88f9892d0b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1265,78 +1265,30 @@ def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: ), trace_only=True, ) -def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16: +def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" # assert other >= 0 - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT16.dtype) - - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT32.dtype) - - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT64.dtype) + if self.dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + elif self.dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + elif self.dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + elif self.dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + else: + raise NotImplementedError(f"Not implemented for type {self.dtype}") - -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT8.dtype) + return op.Cast(result, to=signed_dtype) @torch_op("aten::bitwise_not", trace_only=True) @@ -1380,113 +1332,34 @@ def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: "aten::__rshift__.Scalar", ) ) -def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype) - ) - - -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype) - ) - - -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: +def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) - - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - # 0xFFFFFFFFFFFFFFFF - op.Cast(op.Constant(value_int=-1), to=UINT64.dtype), - other, - direction="RIGHT", - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype) - ) - + if self.dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + mask = ir.tensor(0xFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) + elif self.dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + mask = ir.tensor(-1, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF + else: + raise NotImplementedError(f"Not implemented for type {self.dtype}") -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" negative = op.Less(self, 0) - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) # Simulate arithmetic shift using logical shift # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT" - ) + mask = op.BitShift(mask, other, direction="RIGHT") mask = op.BitwiseNot(mask) # Do logical shift shifted = op.BitShift(self, other, direction="RIGHT") @@ -1494,7 +1367,7 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: negative_shifted = op.BitwiseOr(shifted, mask) # Choose the shifted value based on the sign bit return op.Where( - negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype) + negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype) ) @@ -2156,7 +2029,6 @@ def aten_convolution( return result -@torch_op("aten::convolution", private=True, trace_only=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, @@ -2628,80 +2500,10 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType: @torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True) -def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal: +def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor: """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" - # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims - # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 - # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 - # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2 - if dim1 < 0: - dim1 = dim1 + len(self.shape) - if dim2 < 0: - dim2 = dim2 + len(self.shape) - - self_rank = len(self.shape) - perm = list(range(self_rank)) - perm.remove(dim1) - perm.remove(dim2) - perm.append(dim1) - perm.append(dim2) - - # If rank=2, then axes=[0]; if rank=3, then axes=[1] - # This is because computing diagonal sum is on dim2 after transpose by perm - axes = [self_rank - 2] - - neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col - mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - mask = op.CastLike(mask, self) - self_t = op.Transpose(self, perm=perm) - result = op.Mul(self_t, mask) - result = op.ReduceSum(result, keepdims=False, axes=axes) - # min(row, col) - min_dim_size = op.Min(dim1_size, dim2_size) - # take 2 tensors as example: - # one is 3x5 in size, min_dim_size = 3, dim1_size = 3 - # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5 - # 3 rows x 5 cols 5 rows x 3 cols - # offset diagonal offset diagonal - # ---------------- ---------------- - # -4 0 -6 0 - # -3 0 -5 0 - # -2 1 -4 1 - # -1 2 -3 2 - # 0 3 -2 3 - # 1 3 -1 3 - # 2 3 0 3 - # 3 2 1 2 - # 4 1 2 1 - # 5 0 3 0 - # 6 0 4 0 - - # From above table, we can get the logic below - offset_val = op.Constant(value_ints=[offset]) - if offset < 0: - # row + offset - length = op.Add(dim1_size, offset_val) - start = op.Constant(value_ints=[0]) - else: # offset >= 0 - # col - offset - length = op.Sub(dim2_size, offset_val) - start = offset_val - - # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) - end = op.Add(start, length) - result = op.Slice(result, start, end, axes=axes) - - return result - - -@torch_op("aten::diagonal", trace_only=True) -def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL: - """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" + is_bool = self.dtype == BOOL.dtype # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 @@ -2728,10 +2530,16 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - self_int = op.Cast(self, to=INT64.dtype) - mask_int = op.Cast(mask, to=INT64.dtype) - self_int_t = op.Transpose(self_int, perm=perm) - result = op.Mul(self_int_t, mask_int) + + if is_bool: + self_int = op.Cast(self, to=INT64.dtype) + mask_int = op.Cast(mask, to=INT64.dtype) + self_int_t = op.Transpose(self_int, perm=perm) + result = op.Mul(self_int_t, mask_int) + else: + mask = op.CastLike(mask, self) + self_t = op.Transpose(self, perm=perm) + result = op.Mul(self_t, mask) result = op.ReduceSum(result, keepdims=False, axes=axes) # min(row, col) min_dim_size = op.Min(dim1_size, dim2_size) @@ -2768,7 +2576,9 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) - result = op.Cast(result, to=BOOL.dtype) + + if is_bool: + result = op.Cast(result, to=BOOL.dtype) return result @@ -2879,45 +2689,37 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat: +def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" assert rounding_mode in {"trunc", "floor", None} - if rounding_mode == "trunc": - # Rounds the results of the division towards zero. - # Equivalent to C-style integer division - return aten_trunc(op.Div(self, other)) - if rounding_mode == "floor": - return op.Floor(op.Div(self, other)) - - return op.Div(self, other) + if self.dtype.is_integer(): + quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. + # Equivalent to C-style integer division + result = aten_trunc(quotient) + return op.CastLike(result, self) + if rounding_mode == "floor": + result = op.Floor(quotient) + return op.CastLike(result, self) -@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode_int( - self: TInt, other: TInt, rounding_mode: Optional[str] = None -) -> TensorType: - """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor - - Variant for integer inputs. - """ - assert rounding_mode in {"trunc", "floor", None} + assert rounding_mode is None + # When rounding_mode is None, the return type is float32 + return quotient - quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + # Float inputs if rounding_mode == "trunc": # Rounds the results of the division towards zero. # Equivalent to C-style integer division - result = aten_trunc(quotient) - return op.CastLike(result, self) + return aten_trunc(op.Div(self, other)) if rounding_mode == "floor": - result = op.Floor(quotient) - return op.CastLike(result, self) + return op.Floor(op.Div(self, other)) - assert rounding_mode is None - # When rounding_mode is None, the return type is float32 - return quotient + return op.Div(self, other) @torch_op("aten::dot", trace_only=True) @@ -4859,7 +4661,6 @@ def aten_le(self: TTensor, other: TTensor) -> BOOL: return op.LessOrEqual(self, other) - @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor: """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor""" From aebd259bdfbd2c4a9b178b83094483db50e069b7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:06:16 -0700 Subject: [PATCH 07/23] more Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 53 +++++-------------- 1 file changed, 12 insertions(+), 41 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 88f9892d0b..eb53e9220e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -77,13 +77,11 @@ def aten__local_scalar_dense(self: TensorType) -> TensorType: @torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax_half( - self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool -) -> FLOAT: +def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 - if half_to_float: + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: self = op.Cast(self, to=FLOAT.dtype) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) @@ -93,44 +91,23 @@ def aten__log_softmax_half( return result -@torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax( - self: TFloatHighPrecision, - dim: int, - half_to_float: bool, -) -> TFloatHighPrecision: - """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" +@torch_op("aten::_softmax", trace_only=True) +def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: + """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 + + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: + self = op.Cast(self, to=FLOAT.dtype) + if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.LogSoftmax(self, axis=dim) + result = op.Softmax(self, axis=dim) if self_is_scalar: + # Convert to scalar when input is scalar result = op.Squeeze(result) - return result - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only because we need to cast conditionally based on half_to_float - if half_to_float: - self = op.Cast(self, to=FLOAT.dtype) - - return aten_softmax_no_dtype(self, dim) - - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax( - self: TFloatHighPrecision, dim: int, half_to_float: bool -) -> TFloatHighPrecision: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only to reuse aten_softmax_no_dtype - - del half_to_float # Unused - return aten_softmax_no_dtype(self, dim) + return result @torch_op(("aten::abs", "_operator::abs"), trace_only=True) @@ -380,7 +357,6 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::all.dims", trace_only=True) def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" @@ -499,7 +475,6 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::any.dims", trace_only=True) def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) @@ -739,7 +714,6 @@ def aten_argmax( return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -752,7 +726,6 @@ def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -780,7 +753,6 @@ def aten_argmin( return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -793,7 +765,6 @@ def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" From 7ea409120821a9b609b2b50736e83be841ced830 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:08:28 -0700 Subject: [PATCH 08/23] update Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index eb53e9220e..c2cb6544ce 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -18,21 +18,16 @@ import torch from onnxscript import ( - BFLOAT16, BOOL, COMPLEX64, COMPLEX128, DOUBLE, FLOAT, - FLOAT16, INT8, INT16, INT32, INT64, UINT8, - UINT16, - UINT32, - UINT64, graph, ir, ) @@ -1301,7 +1296,8 @@ def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: "aten::bitwise_right_shift.Scalar_Tensor", "_operator::__rshift__", "aten::__rshift__.Scalar", - ) + ), + trace_only=True, ) def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" @@ -5089,7 +5085,7 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I return result, indices -@torch_op("aten::maximum") +@torch_op("aten::maximum", trace_only=True) def aten_maximum(self: TTensor, other: TTensor) -> TTensor: """maximum(Tensor self, Tensor other) -> Tensor""" @@ -5131,7 +5127,7 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::min") +@torch_op("aten::min", trace_only=True) def aten_min(self: TReal) -> TReal: """min(Tensor self) -> Tensor""" From fb161a93758f2092f149d13bc607939f574b351d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:19:13 -0700 Subject: [PATCH 09/23] Clean up tests Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops_test_data.py | 558 +++--------------- 1 file changed, 78 insertions(+), 480 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 92495d201a..021c92e6b5 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -459,40 +459,9 @@ def _where_input_wrangler( fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, ), - TorchLibOpInfo( - "ops.aten._local_scalar_dense", - core_ops.aten__local_scalar_dense, - ), + TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), - TorchLibOpInfo( - "ops.aten._log_softmax_half", - core_ops.aten__log_softmax_half, - tolerance={torch.float16: (1e-3, 1e-3)}, - ) - .xfail( - reason="PyTorch does not implement _log_softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), - TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half) - .xfail( - reason="PyTorch does not implement _softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), @@ -503,10 +472,7 @@ def _where_input_wrangler( reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("allclose", core_ops.aten_allclose), - TorchLibOpInfo( - "all", - core_ops.aten_all, - ).skip( + TorchLibOpInfo("all", core_ops.aten_all).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -541,32 +507,14 @@ def _where_input_wrangler( reason="zero sized inputs cannot be compared", ), TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), - TorchLibOpInfo( - "addr", - core_ops.aten_addr, - tolerance={torch.float16: (3e-3, 4e-3)}, - ), - TorchLibOpInfo( - "amax", - core_ops.aten_amax, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "amin", - core_ops.aten_amin, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "any", - core_ops.aten_any, - ).skip( + TorchLibOpInfo("addr", core_ops.aten_addr, tolerance={torch.float16: (3e-3, 4e-3)}), + TorchLibOpInfo("amax", core_ops.aten_amax, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("amin", core_ops.aten_amin, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("any", core_ops.aten_any).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), - TorchLibOpInfo( - "any_dim", - core_ops.aten_any_dim, - ).skip( + TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", @@ -584,76 +532,46 @@ def _where_input_wrangler( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_1d_Sequence", - core_ops.aten_atleast_1d_sequence, - ) + TorchLibOpInfo("atleast_1d_Sequence", core_ops.aten_atleast_1d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_2d_Sequence", - core_ops.aten_atleast_2d_sequence, - ) + TorchLibOpInfo("atleast_2d_Sequence", core_ops.aten_atleast_2d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_3d_Sequence", - core_ops.aten_atleast_3d_sequence, - ) + TorchLibOpInfo("atleast_3d_Sequence", core_ops.aten_atleast_3d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), @@ -668,16 +586,10 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p), TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and), - TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16), - TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32), - TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64), - TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8), + TorchLibOpInfo("bitwise_left_shift", core_ops.aten_bitwise_left_shift), TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not), TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or), - TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16), - TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32), - TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), - TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), + TorchLibOpInfo("bitwise_right_shift", core_ops.aten_bitwise_right_shift), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), @@ -695,10 +607,7 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo("chunk", core_ops.aten_chunk).skip( - enabled_if=version_utils.torch_older_than("2.7"), - reason="Test for chunk is not configured for torch<2.7", - ), + TorchLibOpInfo("chunk", core_ops.aten_chunk), TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, @@ -734,7 +643,6 @@ def _where_input_wrangler( TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB TorchLibOpInfo("diagonal", core_ops.aten_diagonal), - TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", @@ -752,7 +660,6 @@ def _where_input_wrangler( # Numbers match sometimes but not other times reason="fixme: off-by-one. https://github.com/microsoft/onnxscript/issues/990", ), - TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( "empty", @@ -762,8 +669,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( - reason="fixme: PyTorch produces int64 output with int32 input", - dtypes=(torch.int32,), + reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,) ) .xfail( reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", @@ -797,21 +703,15 @@ def _where_input_wrangler( TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo( - "full_like", - core_ops.aten_full_like, - ).skip( - enabled_if=ops_test_common.IS_MACOS, - reason="fixme: memory allocation issue on CI", + TorchLibOpInfo("full_like", core_ops.aten_full_like).skip( + enabled_if=ops_test_common.IS_MACOS, reason="fixme: memory allocation issue on CI" ), TorchLibOpInfo("gather", core_ops.aten_gather).skip( matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0, reason="fixme: ORT does not support empty tensors as input", ), TorchLibOpInfo("ge", core_ops.aten_ge), - TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), @@ -825,9 +725,7 @@ def _where_input_wrangler( reason="this Aten overload only supports tensor(bool) as indices", ), TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, - input_wrangler=_index_put_input_wrangler, + "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ) .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.int64, @@ -862,25 +760,13 @@ def _where_input_wrangler( "linspace", core_ops.aten_linspace, tolerance={torch.float16: (2e-2, 2e-3)}, - ) - .xfail( - dtypes=(torch.int64, torch.int32), - reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - ) - .xfail( - variant_name="tensor_overload", + ).xfail( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", - enabled_if=not version_utils.torch_older_than("2.2"), ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), - TorchLibOpInfo("le_bool", core_ops.aten_le_bool), - TorchLibOpInfo( - "lerp", - core_ops.aten_lerp, - tolerance={torch.float16: (2e-3, 2e-1)}, - ), + TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( @@ -919,7 +805,6 @@ def _where_input_wrangler( TorchLibOpInfo("logdet", core_ops.aten_logdet), TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), TorchLibOpInfo("lt", core_ops.aten_lt), - TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", @@ -935,19 +820,12 @@ def _where_input_wrangler( reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), TorchLibOpInfo("maximum", core_ops.aten_maximum), - TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool), - TorchLibOpInfo( - "mean", - core_ops.aten_mean, - input_wrangler=_mean_input_wrangler, - ).skip( + TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo( - "mean_dim", - core_ops.aten_mean_dim, - input_wrangler=_mean_input_wrangler, + "mean_dim", core_ops.aten_mean_dim, input_wrangler=_mean_input_wrangler ).skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="this Aten overload can accept 2 inputs:(self, dim)", @@ -959,15 +837,11 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "min", - core_ops.aten_min, - ).skip( + TorchLibOpInfo("min", core_ops.aten_min).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), TorchLibOpInfo("minimum", core_ops.aten_minimum), - TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool), TorchLibOpInfo("mm", core_ops.aten_mm).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", @@ -976,39 +850,19 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), - TorchLibOpInfo( - "mv", - core_ops.aten_mv, - tolerance={torch.float16: (3e-2, 1e-2)}, - ), + TorchLibOpInfo("mv", core_ops.aten_mv, tolerance={torch.float16: (3e-2, 1e-2)}), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), + TorchLibOpInfo("new_empty", core_ops.aten_new_empty, nondeterministic=True), TorchLibOpInfo( - "new_empty", - core_ops.aten_new_empty, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_empty_strided", - core_ops.aten_new_empty_strided, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_full", - core_ops.aten_new_full, - ), - TorchLibOpInfo( - "new_ones", - core_ops.aten_new_ones, - ), - TorchLibOpInfo( - "new_zeros", - core_ops.aten_new_zeros, + "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True ), + TorchLibOpInfo("new_full", core_ops.aten_new_full), + TorchLibOpInfo("new_ones", core_ops.aten_new_ones), + TorchLibOpInfo("new_zeros", core_ops.aten_new_zeros), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), - TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( "nn.functional.cross_entropy", # use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) @@ -1021,9 +875,7 @@ def _where_input_wrangler( reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( - "nn.functional.dropout", - core_ops.aten_dropout, - input_wrangler=_dropout_input_wrangler, + "nn.functional.dropout", core_ops.aten_dropout, input_wrangler=_dropout_input_wrangler ).skip( matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0, reason="dropout is random so the result not match", @@ -1034,10 +886,7 @@ def _where_input_wrangler( core_ops.aten_embedding_bag, tolerance={torch.float32: (1e-4, 5e-4)}, compare_shape_only_for_output=(1, 2, 3), - ).skip( - dtypes=(torch.float16,), - reason="fixme: results mismatch in torch nightly.", - ), + ).skip(dtypes=(torch.float16,), reason="fixme: results mismatch in torch nightly."), TorchLibOpInfo( "ops.aten.embedding_bag.padding_idx", core_ops.aten_embedding_bag_padding_idx, @@ -1072,10 +921,7 @@ def _where_input_wrangler( tolerance={torch.float16: (5e-2, 1e-2)}, ), TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad) - .skip( - variant_name="circular", - reason="fixme: ORT does not support the circular mode", - ) + .skip(variant_name="circular", reason="fixme: ORT does not support the circular mode") .skip( variant_name="replicate_negative", reason="fixme: The implementation for negative paddings is not correct", @@ -1097,10 +943,7 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.reflection_pad1d", nn_ops.aten_reflection_pad1d, - ).xfail( - dtypes=(torch.int64,), - reason="Torch not implement reflection_pad1d for int64.", - ), + ).xfail(dtypes=(torch.int64,), reason="Torch not implement reflection_pad1d for int64."), TorchLibOpInfo( "nn.functional.reflection_pad2d", nn_ops.aten_reflection_pad2d, @@ -1109,39 +952,16 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "reflect"), reason="this Aten overload need args[1] == 'reflect' for pad mode", ), - TorchLibOpInfo( - "nn.functional.relu", - nn_ops.aten_relu, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "nn.functional.relu6", - nn_ops.aten_relu6, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "ops.aten.replication_pad1d", - nn_ops.aten_replication_pad1d, - ), + TorchLibOpInfo("nn.functional.relu", nn_ops.aten_relu), + TorchLibOpInfo("nn.functional.relu6", nn_ops.aten_relu6), + TorchLibOpInfo("ops.aten.replication_pad1d", nn_ops.aten_replication_pad1d), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, input_wrangler=_replication_pad2d_input_wrangler, - ) - .skip( + ).skip( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", - ) - .xfail( - variant_name="replicate_negative", - enabled_if=not version_utils.torch_older_than("2.2"), - reason="fixme: negative padding is not implemented yet", ), TorchLibOpInfo( "nn.functional.replication_pad3d", @@ -1157,15 +977,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.selu", core_ops.aten_selu), TorchLibOpInfo( - "nn.functional.mse_loss", - nn_ops.aten_mse_loss, - input_wrangler=_mse_loss_input_wrangler, + "nn.functional.mse_loss", nn_ops.aten_mse_loss, input_wrangler=_mse_loss_input_wrangler ), - TorchLibOpInfo( - "nonzero", - core_ops.aten_nonzero, - input_wrangler=_nonzero_input_wrangler, - ) + TorchLibOpInfo("nonzero", core_ops.aten_nonzero, input_wrangler=_nonzero_input_wrangler) .xfail( matcher=lambda sample: sample.kwargs.get("as_tuple"), reason="as_tuple=True is not supported", @@ -1228,26 +1042,19 @@ def _where_input_wrangler( nondeterministic=True, ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( - dtypes=(torch.float16,), - reason="fixme: Shape inference error", + dtypes=(torch.float16,), reason="fixme: Shape inference error" ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), - TorchLibOpInfo( - "remainder", - core_ops.aten_remainder, - ), + TorchLibOpInfo("remainder", core_ops.aten_remainder), TorchLibOpInfo("repeat", core_ops.aten_repeat), TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) .skip( matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), reason=("ignore cases when repeasts is a Tensor"), ) - .skip( - dtypes=(torch.bool,), - reason="bool not supported", - ) + .skip(dtypes=(torch.bool,), reason="bool not supported") .skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="fixme: conversion not implemented if dim is None", @@ -1261,10 +1068,7 @@ def _where_input_wrangler( matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), reason=("ignore cases when repeasts is an int"), ) - .skip( - dtypes=(torch.bool,), - reason="bool not supported", - ) + .skip(dtypes=(torch.bool,), reason="bool not supported") .skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="fixme: conversion not implemented if dim is None", @@ -1294,14 +1098,9 @@ def _where_input_wrangler( complex=True, ), TorchLibOpInfo( - "ops.aten.scalar_tensor", - core_ops.aten_scalar_tensor_complex, - complex=True, + "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, complex=True ), - TorchLibOpInfo( - "scatter_add", - core_ops.aten_scatter_add, - ) + TorchLibOpInfo("scatter_add", core_ops.aten_scatter_add) .xfail( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch. https://github.com/onnx/onnx/issues/4986", @@ -1353,31 +1152,11 @@ def _where_input_wrangler( TorchLibOpInfo( "split_with_sizes", core_ops.aten_split_with_sizes, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( + ).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ), - TorchLibOpInfo( - "split", - core_ops.aten_split, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) + TorchLibOpInfo("split", core_ops.aten_split) .xfail( dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", @@ -1388,10 +1167,7 @@ def _where_input_wrangler( reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1401,11 +1177,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim_complex, - complex=True, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1415,10 +1187,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze", - core_ops.aten_squeeze, - ).skip( + TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( matcher=lambda sample: len(sample.args) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -1427,20 +1196,14 @@ def _where_input_wrangler( TorchLibOpInfo("sub", core_ops.aten_sub, tolerance={torch.float16: (2e-3, 1e-3)}), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB - TorchLibOpInfo( - "t", - core_ops.aten_t, - ).xfail( + TorchLibOpInfo("t", core_ops.aten_t).xfail( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. https://github.com/onnx/onnx/issues/4986", test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), - TorchLibOpInfo( - "tile", - core_ops.aten_tile, - ).skip( + TorchLibOpInfo("tile", core_ops.aten_tile).skip( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", @@ -1468,16 +1231,7 @@ def _where_input_wrangler( reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo( - "unbind", - core_ops.aten_unbind, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( + TorchLibOpInfo("unbind", core_ops.aten_unbind).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ), @@ -1499,10 +1253,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("xlogy", special_ops.aten_special_xlogy), TorchLibOpInfo("zeros", core_ops.aten_zeros), - TorchLibOpInfo( - "arange_start_step", - core_ops.aten_arange_start_step, - ) + TorchLibOpInfo("arange_start_step", core_ops.aten_arange_start_step) .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", @@ -1512,10 +1263,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange_start", - core_ops.aten_arange_start, - ) + TorchLibOpInfo("arange_start", core_ops.aten_arange_start) .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", @@ -1525,10 +1273,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange", - core_ops.aten_arange, - ) + TorchLibOpInfo("arange", core_ops.aten_arange) .xfail( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. https://github.com/microsoft/onnxscript/issues/974", @@ -1551,10 +1296,7 @@ def _where_input_wrangler( TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - ).xfail( - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor", - ), + ).xfail(variant_name="partial_views", reason="ONNX doesn't have partial view for tensor"), TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", @@ -1574,19 +1316,13 @@ def _where_input_wrangler( tolerance={torch.float32: (2e-4, 9e-4)}, ), TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), - TorchLibOpInfo( - "grid_sampler_2d", - core_ops.aten_grid_sampler_2d, - ) + TorchLibOpInfo("grid_sampler_2d", core_ops.aten_grid_sampler_2d) .skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", ) - .skip( - dtypes=(torch.float16,), - reason="fixme: Accuracy is not high enough", - ), + .skip(dtypes=(torch.float16,), reason="fixme: Accuracy is not high enough"), TorchLibOpInfo( "nn.functional.group_norm", nn_ops.aten_group_norm, @@ -1647,10 +1383,7 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "max", - core_ops.aten_max, - ).skip( + TorchLibOpInfo("max", core_ops.aten_max).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), @@ -1708,8 +1441,7 @@ def _where_input_wrangler( reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( - "ops.aten._native_batch_norm_legit.no_stats", - core_ops.aten__native_batch_norm_no_stats, + "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", @@ -1730,10 +1462,6 @@ def _where_input_wrangler( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, tolerance={torch.float16: (1e-2, 7e-3)}, - ).xfail( - dtypes=(torch.float16,), - reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly", - enabled_if=version_utils.torch_older_than("2.2"), ), TorchLibOpInfo( "native_layer_norm", @@ -1815,9 +1543,7 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( - "ops.aten.conv3d", - core_ops.aten_conv3d, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + "ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)} ), TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), @@ -1898,11 +1624,6 @@ def _where_input_wrangler( nn_ops.aten_scaled_dot_product_attention, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype == torch.bool, - reason="this overload takes a non-boolean mask", - ) .skip( matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, reason="dropout is random so the results do not match", @@ -1925,15 +1646,7 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( - device_type="cpu", - reason="_scaled_dot_product_flash_attention only supports CUDA", - ), + ).skip(device_type="cpu", reason="_scaled_dot_product_flash_attention only supports CUDA"), TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, @@ -1941,40 +1654,10 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( + ).skip( enabled_if=not torch.cuda.is_available(), reason="_scaled_dot_product_efficient_attention only supports CUDA", ), - TorchLibOpInfo( - "nn.functional.scaled_dot_product_attention_bool_mask", - nn_ops.aten_scaled_dot_product_attention_bool_mask, - tolerance={torch.float32: (3e-4, 1.5e-5)}, - ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype != torch.bool, - reason="this overload takes a boolean mask", - ) - .skip( - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ) - .xfail( - matcher=lambda sample: len(sample.input.shape) != 4 - or len(sample.args[0].shape) != 4 - or len(sample.args[1].shape) != 4, - reason="torch sdpa is expected to pass in 4d q, k, and v.", - ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, @@ -1994,10 +1677,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bilinear2d.vec", - nn_ops.aten_upsample_bilinear2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, @@ -2017,10 +1697,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bicubic2d.vec", - nn_ops.aten_upsample_bicubic2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, @@ -2029,59 +1706,22 @@ def _where_input_wrangler( and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d", - nn_ops.aten_upsample_nearest1d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d", - nn_ops.aten_upsample_nearest2d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d", - nn_ops.aten_upsample_nearest3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.default", - nn_ops.aten_upsample_trilinear3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.vec", - nn_ops.aten_upsample_trilinear3d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d), + TorchLibOpInfo("ops.aten.upsample_nearest1d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d), + TorchLibOpInfo("ops.aten.upsample_nearest2d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d), + TorchLibOpInfo("ops.aten.upsample_nearest3d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), - TorchLibOpInfo( - "roll", - core_ops.aten_roll, - input_wrangler=_roll_input_wrangler, - ), - TorchLibOpInfo( - "roll", - core_ops.aten_roll_complex, - input_wrangler=_roll_input_wrangler, - complex=True, - ), + TorchLibOpInfo("roll", core_ops.aten_roll, input_wrangler=_roll_input_wrangler), TorchLibOpInfo( "scatter_reduce", core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail( - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ) + .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64), @@ -2153,40 +1793,12 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_left_shift", - ( - "bitwise_left_shift_int8", - "bitwise_left_shift_int16", - "bitwise_left_shift_int32", - "bitwise_left_shift_int64", - ), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_right_shift", - ( - "bitwise_right_shift_int8", - "bitwise_right_shift_int16", - "bitwise_right_shift_int32", - "bitwise_right_shift_int64", - ), -) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) -ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) -ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -2196,20 +1808,6 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.scaled_dot_product_attention", - ("nn.functional.scaled_dot_product_attention_bool_mask",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.celu", - ("nn.functional.celu_type_promoted",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) -) -ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",)) From 8c0ed92a83b9400004d008a05deebc38c314dd40 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:27:44 -0700 Subject: [PATCH 10/23] Fix tests Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test.py | 6 ++++-- tests/function_libs/torch_lib/ops_test_data.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 7ba6f9d37f..fb68daa111 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -98,7 +98,7 @@ def _should_skip_xfail_test_sample( class TestFunctionValidity(unittest.TestCase): @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo @@ -109,10 +109,12 @@ def test_script_function_passes_checker( onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): func = torchlib_op_info.op + if not hasattr(func, "op_schema"): + raise AssertionError(f"Function {func.__name__} does not have op_schema attribute") schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 021c92e6b5..be0746a2a0 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -48,7 +48,6 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions from typing_extensions import Self -from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import core as core_ops from onnxscript.function_libs.torch_lib.ops import fft as fft_ops @@ -1715,7 +1714,6 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), - TorchLibOpInfo("roll", core_ops.aten_roll, input_wrangler=_roll_input_wrangler), TorchLibOpInfo( "scatter_reduce", core_ops.aten_scatter_reduce, @@ -1795,6 +1793,7 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) From a8acb8f8728fdd343ff08357f89670e5ec4bfaa2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:33:37 -0700 Subject: [PATCH 11/23] Bump test dependencies Signed-off-by: Justin Chu --- noxfile.py | 8 +++---- .../function_libs/torch_lib/ops/core.py | 2 +- .../function_libs/torch_lib/ops_test_data.py | 21 +++---------------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/noxfile.py b/noxfile.py index ac9296a5cd..d7e3328d6a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,10 +29,10 @@ "typing_extensions>=4.10", "ml-dtypes", ) -ONNX = "onnx==1.17" -ONNX_RUNTIME = "onnxruntime==1.20.1" -PYTORCH = "torch==2.5.1" -TORCHVISON = "torchvision==0.20.1" +ONNX = "onnx==1.18" +ONNX_RUNTIME = "onnxruntime==1.23.0" +PYTORCH = "torch==2.6.0" +TORCHVISON = "torchvision==0.21.0" TRANSFORMERS = "transformers==4.37.2" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index c2cb6544ce..57497fde44 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1316,7 +1316,7 @@ def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: elif self.dtype.bitwidth == 64: unsigned_dtype = ir.DataType.UINT64 signed_dtype = ir.DataType.INT64 - mask = ir.tensor(-1, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF + mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF else: raise NotImplementedError(f"Not implemented for type {self.dtype}") diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index be0746a2a0..005574aedf 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1150,21 +1150,9 @@ def _where_input_wrangler( ), TorchLibOpInfo( "split_with_sizes", - core_ops.aten_split_with_sizes, - ).xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), - TorchLibOpInfo("split", core_ops.aten_split) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + core_ops.aten_split_with_sizes ), + TorchLibOpInfo("split", core_ops.aten_split), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) .skip( @@ -1230,10 +1218,7 @@ def _where_input_wrangler( reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo("unbind", core_ops.aten_unbind).xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("unbind", core_ops.aten_unbind), TorchLibOpInfo("unflatten", core_ops.aten_unflatten), TorchLibOpInfo("unfold", core_ops.aten_unfold), TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), From 376959efeb21af4e1304862a7fa71494eb06d3e2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 15:41:19 -0700 Subject: [PATCH 12/23] lint Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 005574aedf..41fbbb80a1 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1148,10 +1148,7 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. Tests pass for float32.", ), - TorchLibOpInfo( - "split_with_sizes", - core_ops.aten_split_with_sizes - ), + TorchLibOpInfo("split_with_sizes", core_ops.aten_split_with_sizes), TorchLibOpInfo("split", core_ops.aten_split), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) From 4c505385e2628a88d9436f673d0fa0b044f9d4a2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 16:45:00 -0700 Subject: [PATCH 13/23] go back Signed-off-by: Justin Chu --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index d7e3328d6a..fa0221c2d8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -29,7 +29,7 @@ "typing_extensions>=4.10", "ml-dtypes", ) -ONNX = "onnx==1.18" +ONNX = "onnx==1.17" ONNX_RUNTIME = "onnxruntime==1.23.0" PYTORCH = "torch==2.6.0" TORCHVISON = "torchvision==0.21.0" From d50b92f24e279432f3bfbe677044c48fd61f1530 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 3 Oct 2025 17:02:53 -0700 Subject: [PATCH 14/23] scalar_tensor Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 57497fde44..29aff73ecd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7284,7 +7284,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) def aten_scalar_tensor( - s, + s: float, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", From 12599cfe75f9df2cf7779fae30d54ce380db5fc3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 11:47:58 -0700 Subject: [PATCH 15/23] Fix roll Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 99 ++++++++++++++++++- .../function_libs/torch_lib/ops_test_data.py | 11 +++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6b356c6aae..a01f14b6e1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -7223,11 +7223,106 @@ def aten_rnn_tanh_cell( raise NotImplementedError() -# roll is decomposed by PyTorch +@torch_op("aten::roll", trace_only=True) def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" - raise NotImplementedError() + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + + self_rank = len(self.shape) + if self_rank == 0: + return op.Identity(self) + elif self.shape[0] == 0: # empty tensor + return op.Identity(self) + + # NOTE: In pytorch, default value of dims is an empty list. + if len(dims) == 0: # Empty sequence + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + return _aten_roll_shift_no_dim_onnx(self, shifts[0]) + else: + assert len(shifts) == len(dims) + result = self + for i, shift in enumerate(shifts): + dim = dims[i] + result = _aten_roll_shift_and_dim_onnx(result, shift, dim) + return result + + +@torch_op("aten::roll", trace_only=True, complex=True) +def aten_roll_complex( + self: TTensor, shifts: Sequence[int], dims: Sequence[int] = () +) -> TTensor: + """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + + self_rank = len(self.shape) + if self_rank == 1: + return op.Identity(self) + + if self.shape[0] == 0: # empty tensor + return op.Identity(self) + + self_real = op.Slice(self, [0], [1], axes=[-1]) + self_imag = op.Slice(self, [1], [2], axes=[-1]) + if not dims: + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0]) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0]) + + result = op.Concat(shift_real, shift_imag, axis=-1) + + else: + assert len(shifts) == len(dims) + for i, dim in enumerate(dims): + self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim) + + result = op.Concat(self_real, self_imag, axis=-1) + return result + + +def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: + neg_1 = op.Constant(value_ints=[-1]) + # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] + self_flatten = op.Reshape(self, neg_1) + # Compute slice length + if shift < 0: + # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end + slice_length = op.Constant(value_ints=[-shift]) + else: + # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end + # The effect equals to move [D] to the beginning + slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift]) + # Get second part of the tensor, e.g. [A,B,C] + suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) + # Get first part of the tensor, e.g. [D] + prefix = op.Slice(self_flatten, slice_length, op.Reshape(op.Size(self_flatten), neg_1)) + # Concat first+second together, e.g. [D,A,B,C] + result = op.Concat(prefix, suffix, axis=0) + return op.Reshape(result, op.Shape(self)) + + +def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor: + neg_1 = op.Constant(value_ints=[-1]) + dim_tensor = op.Constant(value_ints=[dim]) + if shift < 0: + slice_length = op.Constant(value_ints=[-shift]) + else: + slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift]) + # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix + suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) + prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) + result = op.Concat(prefix, suffix, axis=dim) + return result def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4ec336348e..bacd24ad1c 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1699,6 +1699,17 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), + TorchLibOpInfo( + "roll", + core_ops.aten_roll, + input_wrangler=_roll_input_wrangler, + ), + TorchLibOpInfo( + "roll", + core_ops.aten_roll_complex, + input_wrangler=_roll_input_wrangler, + complex=True, + ), TorchLibOpInfo( "scatter_reduce", core_ops.aten_scatter_reduce, From 40a0ee1d3a80d146f9ca18701f10fb5c48d052b2 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 11:50:02 -0700 Subject: [PATCH 16/23] pytorch 2.7 Signed-off-by: Justin Chu --- noxfile.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/noxfile.py b/noxfile.py index 3e1e85ee8a..60c2bb901b 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,8 +30,8 @@ ) ONNX = "onnx==1.17" ONNX_RUNTIME = "onnxruntime==1.23.0" -PYTORCH = "torch==2.6.0" -TORCHVISON = "torchvision==0.21.0" +PYTORCH = "torch==2.7.1" +TORCHVISON = "torchvision==0.22.1" TRANSFORMERS = "transformers==4.37.2" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", From 30c7609f28287f87e08b0aea5d349b1933404e46 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 12:15:56 -0700 Subject: [PATCH 17/23] skip negative tests Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index bacd24ad1c..597b26a1da 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -964,6 +964,9 @@ def _where_input_wrangler( ).skip( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", + ).skip( + variant_name="replicate_negative", + reason="fixme: The implementation for negative paddings is not correct. Potentially an ORT issue", ), TorchLibOpInfo( "nn.functional.replication_pad3d", From 9211f2674f63beab9ece8e39579a4bcf563084cd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 12:21:13 -0700 Subject: [PATCH 18/23] group norm Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 4 +++- tests/function_libs/torch_lib/ops_test_data.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a01f14b6e1..ca22d42ebd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6023,7 +6023,9 @@ def aten_native_group_norm( sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) # In Pytorch, vstd = 1/(sqrt(var + eps)) var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) - rstd = op.Div(1.0, op.Sqrt(var + eps)) + eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype)) + one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype)) + rstd = op.Div(one, op.Sqrt(op.Add(var, eps))) # Get the correct shape [N, group] for mean again mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False) return norm_result, mean, rstd diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 597b26a1da..df932c9c74 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -961,10 +961,12 @@ def _where_input_wrangler( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, input_wrangler=_replication_pad2d_input_wrangler, - ).skip( + ) + .skip( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", - ).skip( + ) + .skip( variant_name="replicate_negative", reason="fixme: The implementation for negative paddings is not correct. Potentially an ORT issue", ), From 7dd4cebbd72bc3a4c142adcec03b4e8d8c8b7e2d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 12:22:19 -0700 Subject: [PATCH 19/23] softmax Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index df932c9c74..d5498c4b34 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -459,7 +459,11 @@ def _where_input_wrangler( tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, ), TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), - TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), + TorchLibOpInfo( + "ops.aten._log_softmax", + core_ops.aten__log_softmax, + tolerance={torch.float16: (1e-3, 1e-3)}, + ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) From 6ba6f521fb49b6b31ea28539b4e55add99616c7d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 7 Oct 2025 12:24:14 -0700 Subject: [PATCH 20/23] linspace Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index d5498c4b34..b60fd8cf31 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -766,9 +766,14 @@ def _where_input_wrangler( "linspace", core_ops.aten_linspace, tolerance={torch.float16: (2e-2, 2e-3)}, - ).xfail( + ) + .xfail( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", + ) + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), + reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), From 7b349755941a75535e4b00fb86a4617f82483bc6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 8 Oct 2025 08:55:07 -0700 Subject: [PATCH 21/23] Skip tests Signed-off-by: Justin Chu --- onnxscript/backend/onnx_export_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 49eb398750..1f913ed897 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -84,6 +84,7 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): ), skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"), skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"), + skip(r"^test_attention", "ONNX Runtime 1.23 fails on these tests"), ) if sys.platform == "win32": From cdc917d1a82cb30892ed329257537bfe9ac0ffcc Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 8 Oct 2025 09:45:17 -0700 Subject: [PATCH 22/23] Update requirements-ort-nightly.txt --- requirements/ci/requirements-ort-nightly.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index 4ed908b4e2..b54550738b 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -onnxruntime==1.23.0.dev20250517001 +onnxruntime==1.23.0.dev20251001001 From fbc054043f9d417b3ef8b0e1df2e40efac6fe692 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 8 Oct 2025 10:04:03 -0700 Subject: [PATCH 23/23] Update nn.py --- onnxscript/function_libs/torch_lib/ops/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 80be5655f4..2a7a46ec28 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -1800,7 +1800,7 @@ def aten_scaled_dot_product_attention( query: TFloat, key: TFloat, value: TFloat, - attn_mask: Optional[TFloat] = None, + attn_mask: Optional[TensorType] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None,