Skip to content

Commit 9a28a08

Browse files
authored
Merge branch 'microsoft:main' into conv_and_hardswish_fusion
2 parents 6216821 + 9036fab commit 9a28a08

File tree

98 files changed

+5437
-1618
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+5437
-1618
lines changed

.github/workflows/pages.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ jobs:
4242
- name: Build documentation
4343
run: python -m sphinx docs dist/html
4444
- name: Upload documentation archive
45-
uses: actions/upload-pages-artifact@v3
45+
uses: actions/upload-pages-artifact@v4
4646
with:
4747
path: 'dist/html'
4848
- name: Deploy to GitHub Pages

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.4.0
1+
0.4.1

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
"packaging",
4343
"protobuf",
4444
)
45-
ONNX_IR = "onnx_ir==0.1.3"
45+
ONNX_IR = "onnx_ir==0.1.7"
4646
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"
4747

4848

onnxscript/_framework_apis/torch_2_8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Stable APIs for PyTorch 2.7."""
3+
"""Stable APIs for PyTorch 2.8."""
44

55
from __future__ import annotations
66

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Stable APIs for PyTorch 2.9."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"check_model",
9+
"convert_version",
10+
"get_torchlib_ops",
11+
"optimize",
12+
"save_model_with_external_data",
13+
]
14+
15+
from typing import TYPE_CHECKING
16+
17+
from onnxscript import version_converter
18+
from onnxscript._framework_apis.torch_2_8 import (
19+
check_model,
20+
get_torchlib_ops,
21+
optimize,
22+
save_model_with_external_data,
23+
)
24+
25+
if TYPE_CHECKING:
26+
import onnx_ir as ir
27+
28+
29+
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
30+
"""Convert the model to the specified ONNX opset version.
31+
32+
Starting from PyTorch 2.9, down conversion is turned on and supported.
33+
"""
34+
version_converter.convert_version(model, target_version, fallback=True)
35+
return model

onnxscript/backend/onnx_export.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
_SINGLE_INDENT = " "
1515

16+
_SMALL_TENSOR_SIZE = 4
17+
1618
kwlist = {
1719
"False",
1820
"None",
@@ -119,7 +121,7 @@ def renamer(name):
119121

120122
def _translate_type(onnx_type):
121123
"""Converts a onnx type into a type defined by *onnxscript*."""
122-
return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type)
124+
return onnxscript.onnx_types.onnx_type_to_onnxscript_repr(onnx_type, reversible=False)
123125

124126

125127
def _translate_signature(inputs, outputs):
@@ -350,25 +352,33 @@ def _translate_graph_body(self, graph, opsets, indent=0):
350352
if hasattr(graph, "initializer"):
351353
for init in graph.initializer:
352354
if self.skip_initializers:
353-
init_py_name = self._translate_onnx_var(init.name)
354-
if init_py_name in self.skipped_initializers:
355-
raise RuntimeError(
356-
f"Initializer {init.name!r} is already present in skipped_initializers."
357-
)
358-
self.skipped_initializers[init_py_name] = init
359-
continue
355+
size = 1
356+
for d in init.dims:
357+
size *= d
358+
if size > _SMALL_TENSOR_SIZE:
359+
init_py_name = self._translate_onnx_var(init.name)
360+
if init_py_name in self.skipped_initializers:
361+
raise RuntimeError(
362+
f"Initializer {init.name!r} is already present in skipped_initializers."
363+
)
364+
self.skipped_initializers[init_py_name] = init
365+
continue
360366
node = onnx.helper.make_node( # noqa: TID251
361367
"Constant",
362368
[],
363369
[self._translate_onnx_var(init.name)], # type: ignore[list-item]
364370
value=init,
365371
)
366-
code.append(self._translate_node(node, opsets, indent=indent))
372+
pyinit = self._translate_node(node, opsets, indent=indent)
373+
if pyinit:
374+
code.append(pyinit)
367375
if hasattr(graph, "sparse_initializer") and len(graph.sparse_initializer) > 0:
368376
raise NotImplementedError("Unable to convert sparse_initilizer into python.")
369377
for node in graph.node:
370378
pynode = self._translate_node(node, opsets, indent=indent)
371379
if pynode:
380+
if node.name:
381+
pynode += f" # {node.name}"
372382
code.append(pynode)
373383

374384
final = "\n".join(code)
@@ -418,7 +428,8 @@ def _translate_attributes(self, node):
418428
def _translate_if(self, node, opsets, indent=0):
419429
"""Translates a node If into python."""
420430
sindent = _SINGLE_INDENT * indent
421-
code = [f"{sindent}if {node.input[0]}:"]
431+
cond = self._translate_onnx_var_ref(node.input[0])
432+
code = [f"{sindent}if {cond}:"]
422433
if len(node.attribute) != 2:
423434
raise RuntimeError(
424435
f"Node {node.op_type!r} expected two attributes not {len(node.attribute)}."
@@ -502,17 +513,21 @@ def _translate_loop(self, node, opsets, indent=0):
502513

503514
rows.extend(self._emit_assign(formal_ins, actual_ins, indent))
504515

516+
if node.name:
517+
node_name = " # " + node.name
518+
else:
519+
node_name = ""
505520
if use_iter_var and not use_loop_cond:
506-
rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
521+
rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
507522
# The following is a hacky way to suppress the generation of
508523
# "cond_out = cond_in", which ONNX forces for a FOR loop.
509524
# TODO: a cleaner solution for this.
510525
self._name_remappings[-1][cond_out] = self._translate_onnx_var(cond_in)
511526
elif not use_iter_var and use_loop_cond:
512-
rows.append(f"{sindent}while {py_cond}:")
527+
rows.append(f"{sindent}while {py_cond}:{node_name}")
513528
elif use_iter_var and use_loop_cond:
514529
# TODO: This needs fixing
515-
rows.append(f"{sindent}for {iter_var} in range({n_iter}):")
530+
rows.append(f"{sindent}for {iter_var} in range({n_iter}):{node_name}")
516531
rows.append(f"{sindent}{_SINGLE_INDENT}if not {py_cond}:")
517532
rows.append(f"{sindent}{_SINGLE_INDENT * 2}break")
518533
else:
@@ -734,11 +749,13 @@ def _substitute_initializers(
734749

735750
def generate_rand(name: str, value: TensorProto) -> str:
736751
shape = ",".join(str(d) for d in value.dims)
737-
if value.data_type != TensorProto.FLOAT:
738-
raise NotImplementedError(
739-
f"Unable to generate random initializer for data type {value.data_type}."
740-
)
741-
return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
752+
if value.data_type == TensorProto.FLOAT:
753+
return f"{__}{name} = np.random.rand({shape}).astype(np.float32)"
754+
if value.data_type == TensorProto.INT8:
755+
return f"{__}{name} = np.random.randint(-128, 127, size=({shape},), dtype=np.int8)"
756+
raise NotImplementedError(
757+
f"Unable to generate random initializer for data type {value.data_type}."
758+
)
742759

743760
random_initializer_values = "\n".join(
744761
generate_rand(key, value) for key, value in self.skipped_initializers.items()

onnxscript/backend/onnx_export_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True):
9999
"^test_resize_upsample_scales_linear_half_pixel_symmetric",
100100
"cannot import module, import_module does not work",
101101
),
102+
# tests are too unstable on Windows, not always the same ones are failing.
103+
skip("test_", "cannot import module"),
102104
)
103105

104106

onnxscript/function_libs/torch_lib/ops/common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value"
66
from __future__ import annotations
77

8+
from collections.abc import Sequence
9+
810
import numpy.typing as npt
911
import onnx
1012

@@ -78,3 +80,22 @@ def constant(
7880
A constant node.
7981
"""
8082
return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype)))
83+
84+
85+
def merge_dims(dims: Sequence[int | INT64]) -> INT64:
86+
"""Concatenate dimensions into a single value."""
87+
88+
if not dims:
89+
return op.Constant(value_ints=ir.AttrInt64s("value_ints", []))
90+
91+
neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1]))
92+
93+
result_dims = [
94+
op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d)
95+
for d in dims
96+
]
97+
98+
# Set the output type to INT64 so op.Concat can be used
99+
for dim in result_dims:
100+
dim.dtype = ir.DataType.INT64
101+
return op.Concat(*result_dims, axis=0)

0 commit comments

Comments (0)