Skip to content

Commit c0abfbc

Browse files
committed
Merge branch 'main' into justinchu/consolidate-overloads
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
2 parents d50b92f + 8e449da commit c0abfbc

File tree

12 files changed

+217
-67
lines changed

12 files changed

+217
-67
lines changed

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
steps:
4646
- uses: actions/checkout@v5
4747
- name: Setup Python
48-
uses: actions/setup-python@v5
48+
uses: actions/setup-python@v6
4949
with:
5050
# Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset.
5151
python-version: "3.10"

.github/workflows/main.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jobs:
5959
steps:
6060
- uses: actions/checkout@v5
6161
- name: Setup Python ${{ matrix.python-version }}
62-
uses: actions/setup-python@v5
62+
uses: actions/setup-python@v6
6363
with:
6464
python-version: ${{ matrix.python-version }}
6565
- name: Install nox
@@ -97,7 +97,7 @@ jobs:
9797
steps:
9898
- uses: actions/checkout@v5
9999
- name: Setup Python
100-
uses: actions/setup-python@v5
100+
uses: actions/setup-python@v6
101101
with:
102102
python-version: "3.10"
103103
cache: pip
@@ -121,7 +121,7 @@ jobs:
121121
steps:
122122
- uses: actions/checkout@v5
123123
- name: Setup Python
124-
uses: actions/setup-python@v5
124+
uses: actions/setup-python@v6
125125
- name: Update readme
126126
run: |
127127
python docs/update_readme.py

.github/workflows/pages.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
- name: Setup Pages
3030
uses: actions/configure-pages@v4
3131
- name: Setup Python
32-
uses: actions/setup-python@v5
32+
uses: actions/setup-python@v6
3333
with:
3434
python-version: "3.10"
3535
- uses: actions/checkout@v5

noxfile.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313

1414
COMMON_TEST_DEPENDENCIES = (
15-
"beartype==0.17.2",
1615
"expecttest==0.1.6",
1716
"hypothesis",
1817
"numpy",

onnxscript/_internal/runtime_typing.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
11611161
return op.CastLike(sampled, self)
11621162

11631163

1164+
@torch_op("aten::bilinear", trace_only=True)
11641165
def aten_bilinear(
11651166
input1: TensorType,
11661167
input2: TensorType,
@@ -1169,7 +1170,23 @@ def aten_bilinear(
11691170
) -> TensorType:
11701171
"""bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
11711172

1172-
raise NotImplementedError()
1173+
# Bilinear transformation: y = x1^T A x2 + b
1174+
# input1 shape: (..., in1_features)
1175+
# input2 shape: (..., in2_features)
1176+
# weight shape: (out_features, in1_features, in2_features)
1177+
# bias shape: (out_features) - optional
1178+
# output shape: (..., out_features)
1179+
1180+
# Use Einsum to compute the bilinear transformation
1181+
# "...i,oij,...j->...o" means:
1182+
# - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
1183+
result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
1184+
1185+
# Add bias if provided
1186+
if bias is not None:
1187+
result = op.Add(result, bias)
1188+
1189+
return result
11731190

11741191

11751192
def aten_binary_cross_entropy_with_logits(
@@ -7284,7 +7301,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
72847301

72857302
@torch_op("aten::scalar_tensor", trace_only=True)
72867303
def aten_scalar_tensor(
7287-
s: float,
7304+
s: TensorType,
72887305
dtype: int = FLOAT.dtype,
72897306
layout: str = "",
72907307
device: str = "",
@@ -7322,17 +7339,35 @@ def aten_scalar_tensor_complex(
73227339
return result
73237340

73247341

7325-
@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True)
7326-
def aten_scatter(
7327-
self: TReal,
7342+
@torch_op("aten::scatter.src", trace_only=True)
7343+
def aten_scatter_src(
7344+
self: TTensor,
73287345
dim: int, # we have to use int here because ScatterElements() will use this attribute
73297346
index: TInt,
7330-
src: TReal,
7331-
) -> TReal:
7332-
"""scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
7347+
src: TTensor,
7348+
) -> TTensor:
7349+
"""scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"""
7350+
if len(index.shape) == 0:
7351+
index = op.Unsqueeze(index, [0])
7352+
if len(src.shape) == 0:
7353+
src = op.Unsqueeze(src, [0])
7354+
return op.ScatterElements(self, index, src, axis=dim)
73337355

7334-
update = op.Expand(src, op.Shape(index))
7335-
return op.ScatterElements(self, index, update, axis=dim)
7356+
7357+
@torch_op("aten::scatter.value", trace_only=True)
7358+
def aten_scatter_value(
7359+
self: TTensor,
7360+
dim: int, # we have to use int here because ScatterElements() will use this attribute
7361+
index: TInt,
7362+
value: float,
7363+
) -> TTensor:
7364+
"""scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"""
7365+
# Ensure value is a scalar tensor and expand it to match index shape
7366+
if len(index.shape) == 0:
7367+
index = op.Unsqueeze(index, [0])
7368+
scalar_tensor = ir.tensor([value], dtype=self.dtype)
7369+
src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor)
7370+
return op.ScatterElements(self, index, src, axis=dim)
73367371

73377372

73387373
@torch_op("aten::scatter_add", trace_only=True)

onnxscript/irbuilder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def __str__(self):
214214

215215
def debug_print(self):
216216
if logger.isEnabledFor(logging.DEBUG):
217-
logger.debug("%s: %s", type(self), str(self))
217+
logger.debug("%s: %s", type(self), self)
218218

219219
def to_node_proto(self, node_name: str) -> onnx.NodeProto:
220220
n = helper.make_node(

requirements-dev.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@ sphinx>=6
1717
myst_nb
1818
chardet
1919

20-
# Torch lib
21-
beartype!=0.16.0
22-
2320
# Testing
2421
expecttest==0.1.6
2522
hypothesis
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
onnx-weekly==1.20.0.dev20250901
1+
onnx-weekly==1.20.0.dev20251006
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# This file is auto updated by dependabot
22
lintrunner-adapters>=0.8.0
33
# RUFF, RUFF-FIX
4-
ruff==0.13.1
4+
ruff==0.13.2
55
# MYPY
66
mypy==1.10.1
7-
types-PyYAML==6.0.12.20250402
7+
types-PyYAML==6.0.12.20250915
88
# PYLINT
99
pylint==3.3.6
1010
# EDITORCONFIG-CHECKER
11-
editorconfig-checker==3.2.0
11+
editorconfig-checker==3.4.0

0 commit comments

Comments
 (0)