Skip to content

Commit 0a67bfd

Browse files
committed
resolve conflict
2 parents 2b6775a + ad83914 commit 0a67bfd

File tree

12 files changed

+95
-40
lines changed

12 files changed

+95
-40
lines changed

.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ jobs:
8383
token: ${{ secrets.CODECOV_TOKEN }}
8484
- name: Upload torchlib error reports
8585
if: always()
86-
uses: actions/upload-artifact@v4
86+
uses: actions/upload-artifact@v5
8787
with:
8888
name: Error reports (${{ matrix.name }}-${{ matrix.os }})
8989
path: error_reports

onnxscript/_framework_apis/torch_2_5.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
]
1414

1515
import dataclasses
16+
import importlib.util
1617
import os
1718
import pathlib
1819
from typing import Callable
@@ -63,7 +64,9 @@ def check_model(model: ir.Model) -> None:
6364
del model # Unused yet
6465

6566

66-
def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
67+
def save_model_with_external_data(
68+
model: ir.Model, model_path: str | os.PathLike, verbose: bool = False
69+
) -> None:
6770
"""Save the model with external data. The model is unchanged after saving."""
6871

6972
# TODO(#1835): Decide if we want to externalize large attributes as well
@@ -78,7 +81,31 @@ def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike
7881
destination_path = pathlib.Path(model_path)
7982
data_path = f"{destination_path.name}.data"
8083

81-
ir.save(model, model_path, external_data=data_path)
84+
# Show a progress bar if verbose is True and tqdm is installed
85+
use_tqdm = verbose and importlib.util.find_spec("tqdm") is not None
86+
87+
if use_tqdm:
88+
import tqdm # pylint: disable=import-outside-toplevel
89+
90+
with tqdm.tqdm() as pbar:
91+
total_set = False
92+
93+
def callback(
94+
tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo
95+
) -> None:
96+
nonlocal total_set
97+
if not total_set:
98+
pbar.total = metadata.total
99+
total_set = True
100+
101+
pbar.update()
102+
pbar.set_description(
103+
f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}"
104+
)
105+
106+
ir.save(model, model_path, external_data=data_path, callback=callback)
107+
else:
108+
ir.save(model, model_path, external_data=data_path)
82109

83110

84111
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:

onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import parameterized
1313

1414
import onnxscript
15-
import onnxscript.function_libs.torch_lib.ops # Import to populate registry
15+
import onnxscript.function_libs.torch_lib.ops # Import to populate registry # noqa: F401
1616
from onnxscript.function_libs.tools.torch_lib import deduce_type_constraints
1717
from onnxscript.function_libs.torch_lib import registration
1818

onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from typing import Any, Dict, List, Sequence
1717

1818
import torch
19-
import torchgen.gen
2019
import torchgen.model
2120
from torch._ops import _OpNamespace
2221
from torchgen.model import FunctionSchema

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8533,6 +8533,14 @@ def aten_trunc(self: TFloat) -> TFloat:
85338533
return op.Floor(op.Abs(self)) * op.Sign(self)
85348534

85358535

8536+
@torch_op("math::trunc", trace_only=True)
8537+
def python_math_trunc(self: TFloat) -> TInt:
8538+
"""trunc(Tensor self) -> Tensor"""
8539+
# NOTE: This is used in SymInt/SymBool/SymFloat context, so
8540+
# we don't expect overflow to happen here.
8541+
return op.Cast(self, to=INT64.dtype)
8542+
8543+
85368544
@torch_op("aten::type_as", trace_only=True)
85378545
def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2:
85388546
"""type_as(Tensor self, Tensor other) -> Tensor"""

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,33 @@ def _adjust_attributes_of_avg_pool(
114114
return (kernel_shape, strides, pads)
115115

116116

117+
def _aten_avg_pool_onnx(
118+
self: TFloat,
119+
kernel_shape: Sequence[int],
120+
strides: Sequence[int],
121+
pads: Sequence[int],
122+
ceil_mode: bool,
123+
count_include_pad: bool,
124+
) -> TFloat:
125+
self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1
126+
if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
127+
self = op.Unsqueeze(self, [0])
128+
129+
result = op.AveragePool(
130+
self,
131+
ceil_mode=ceil_mode,
132+
count_include_pad=count_include_pad,
133+
kernel_shape=kernel_shape,
134+
pads=pads,
135+
strides=strides,
136+
)
137+
138+
if self_rank_is_unbatched_rank:
139+
result = op.Squeeze(result, [0])
140+
141+
return result
142+
143+
117144
@torch_op("aten::avg_pool1d", trace_only=True)
118145
def aten_avg_pool1d(
119146
self: TFloat,
@@ -134,16 +161,7 @@ def aten_avg_pool1d(
134161
expand_size, kernel_size, stride, padding
135162
)
136163

137-
result = op.AveragePool(
138-
self,
139-
ceil_mode=ceil_mode,
140-
count_include_pad=count_include_pad,
141-
kernel_shape=kernel_shape,
142-
pads=pads,
143-
strides=strides,
144-
)
145-
146-
return result
164+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
147165

148166

149167
@torch_op("aten::avg_pool2d", trace_only=True)
@@ -167,15 +185,6 @@ def aten_avg_pool2d(
167185
expand_size, kernel_size, stride, padding
168186
)
169187

170-
result = op.AveragePool(
171-
self,
172-
ceil_mode=ceil_mode,
173-
count_include_pad=count_include_pad,
174-
kernel_shape=kernel_shape,
175-
pads=pads,
176-
strides=strides,
177-
)
178-
179188
# TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
180189
# mask = [
181190
# 1, 2, 3, S,..3, 2, 1
@@ -189,7 +198,7 @@ def aten_avg_pool2d(
189198
# S is stride size, in this case S=4,
190199
# S may dup lot of times according to the image size
191200

192-
return result
201+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
193202

194203

195204
def aten_avg_pool2d_backward(
@@ -228,15 +237,6 @@ def aten_avg_pool3d(
228237
expand_size, kernel_size, stride, padding
229238
)
230239

231-
result = op.AveragePool(
232-
self,
233-
kernel_shape=kernel_shape,
234-
strides=strides,
235-
pads=pads,
236-
count_include_pad=count_include_pad,
237-
ceil_mode=ceil_mode,
238-
)
239-
240240
# TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
241241
# mask = [
242242
# 1, 2, 3, S,..3, 2, 1
@@ -250,7 +250,7 @@ def aten_avg_pool3d(
250250
# S is stride size, in this case S=4,
251251
# S may dup lot of times according to the image size
252252

253-
return result
253+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
254254

255255

256256
def aten_avg_pool3d_backward(

onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import unittest
44

55
import numpy as np
6-
import onnx
76
import onnx.checker
87
import onnx.shape_inference
98
import onnxruntime

onnxscript/rewriter/rules/fusion/_layer_norm_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import onnx_ir as ir
77

8-
import onnxscript
98
import onnxscript.optimizer
109
import onnxscript.rewriter.testing
1110
from onnxscript import FLOAT, OnnxFunction, script
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
onnx-weekly==1.20.0.dev20251006
1+
onnx-weekly==1.20.0.dev20251027

requirements/lintrunner/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# This file is auto updated by dependabot
22
lintrunner-adapters>=0.8.0
33
# RUFF, RUFF-FIX
4-
ruff==0.13.2
4+
ruff==0.14.2
55
# MYPY
66
mypy==1.10.1
77
types-PyYAML==6.0.12.20250915

0 commit comments

Comments
 (0)