Skip to content

Commit ceb6869

Browse files
committed
refactor
1 parent b82bb62 commit ceb6869

File tree

4 files changed

+13
-10
lines changed

4 files changed

+13
-10
lines changed

AGENTS.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ onnx2fx/
4040
│ ├── analyze.py # Model analysis utilities
4141
│ ├── attributes.py # ONNX attribute parsing
4242
│ ├── dtype.py # ONNX to PyTorch dtype mapping
43+
│ ├── external_data.py # External data utilities
4344
│ ├── names.py # Name sanitization utilities
4445
│ ├── op_helpers.py # Op helper utilities
4546
│ └── training.py # Training utilities (make_trainable)
@@ -63,7 +64,6 @@ Core Requirements:
6364
- Python >= 3.11
6465
- PyTorch >= 2.9.0
6566
- ONNX >= 1.19.1
66-
- onnxscript >= 0.3.0
6767

6868
Development Tools:
6969
- uv (recommended for development)
@@ -147,7 +147,7 @@ def bias_gelu(builder, node):
147147
- `builder.has_value(name)` - Check if value exists in environment
148148
- `builder.call_function(func, args, kwargs)` - Create function call node
149149
- `builder.call_module(module_name, args, kwargs)` - Create module call node
150-
- `builder.add_submodule(name, module)` - Register a submodule (returns safe name)
150+
- `builder.register_submodule(name, module)` - Register a submodule (returns safe name)
151151
- `builder.opset_version` - Get current opset version for default domain
152152
- `builder.get_opset_version(domain)` - Get opset version for specific domain
153153

@@ -161,7 +161,7 @@ For parsing ONNX node attributes, use functions from `onnx2fx.utils.attributes`:
161161
### Public API
162162

163163
#### Core Functions
164-
- `convert(model)` - Convert ONNX model to FX GraphModule
164+
- `convert(model, *, base_dir=None, memmap_external_data=False)` - Convert ONNX model to FX GraphModule
165165
- `make_trainable(module)` - Convert buffers to trainable parameters for training
166166

167167
#### Model Analysis

README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Yet another ONNX to PyTorch FX converter.
99
## Features
1010

1111
- **Simple API**: Convert ONNX models with a single function call
12-
- **Extensive Operator Support**: 170+ ONNX operators including standard and Microsoft domain operators
12+
- **Extensive Operator Support**: Wide ONNX operator coverage including standard and Microsoft domain operators
1313
- **Multi-Opset Version Support**: Automatic selection of version-specific operator handlers based on model opset
1414
- **Custom Operator Registration**: Easily extend support for unsupported or custom ONNX operators
1515
- **PyTorch FX Output**: Get a `torch.fx.GraphModule` for easy inspection, optimization, and compilation
@@ -36,22 +36,21 @@ The following models have been tested and verified to work with onnx2fx:
3636
- Python >= 3.11
3737
- PyTorch >= 2.9.0
3838
- ONNX >= 1.19.1
39-
- onnxscript >= 0.3.0
4039

4140
### From Source
4241

4342
```bash
4443
git clone https://github.com/mshr-h/onnx2fx.git
4544
cd onnx2fx
46-
pip install .
45+
uv sync
4746
```
4847

4948
### Development Installation
5049

5150
```bash
5251
git clone https://github.com/mshr-h/onnx2fx.git
5352
cd onnx2fx
54-
pip install -e ".[dev]"
53+
uv sync --dev
5554
```
5655

5756
## Quick Start
@@ -70,6 +69,10 @@ fx_module = convert("model.onnx")
7069
onnx_model = onnx.load("model.onnx")
7170
fx_module = convert(onnx_model)
7271

72+
# For models with external data, you can pass base_dir.
73+
# memmap_external_data avoids loading external data into memory.
74+
fx_module = convert("model.onnx", base_dir="/path/to/model_dir", memmap_external_data=True)
75+
7376
# Run inference
7477
input_tensor = torch.randn(1, 3, 224, 224)
7578
output = fx_module(input_tensor)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ authors = [
99
requires-python = ">=3.11"
1010
dependencies = [
1111
"onnx>=1.19.1",
12-
"onnxscript>=0.5.7",
1312
"torch>=2.9.0",
1413
]
1514

@@ -20,6 +19,7 @@ build-backend = "uv_build"
2019
[dependency-groups]
2120
dev = [
2221
"huggingface-hub>=0.36.0",
22+
"onnxscript>=0.5.7",
2323
"onnxruntime>=1.23.2",
2424
"pre-commit>=4.5.1",
2525
"pytest>=9.0.2",

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)