Skip to content

Commit 93cc7ec

Browse files
authored
fix: fix optree compatibility for multi-tree-map with None values (#195)
1 parent 86b167c commit 93cc7ec

24 files changed

+126
-124
lines changed

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ jobs:
2929
submodules: "recursive"
3030
fetch-depth: 1
3131

32-
- name: Set up Python 3.8
32+
- name: Set up Python 3.9
3333
uses: actions/setup-python@v4
3434
with:
35-
python-version: "3.8"
35+
python-version: "3.9"
3636
update-environment: true
3737

3838
- name: Setup CUDA Toolkit

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ ci:
99
default_stages: [commit, push, manual]
1010
repos:
1111
- repo: https://github.com/pre-commit/pre-commit-hooks
12-
rev: v4.4.0
12+
rev: v4.5.0
1313
hooks:
1414
- id: check-symlinks
1515
- id: destroyed-symlinks
@@ -26,11 +26,11 @@ repos:
2626
- id: debug-statements
2727
- id: double-quote-string-fixer
2828
- repo: https://github.com/pre-commit/mirrors-clang-format
29-
rev: v16.0.6
29+
rev: v17.0.4
3030
hooks:
3131
- id: clang-format
3232
- repo: https://github.com/astral-sh/ruff-pre-commit
33-
rev: v0.0.287
33+
rev: v0.1.5
3434
hooks:
3535
- id: ruff
3636
args: [--fix, --exit-non-zero-on-fix]
@@ -39,11 +39,11 @@ repos:
3939
hooks:
4040
- id: isort
4141
- repo: https://github.com/psf/black
42-
rev: 23.7.0
42+
rev: 23.11.0
4343
hooks:
4444
- id: black-jupyter
4545
- repo: https://github.com/asottile/pyupgrade
46-
rev: v3.10.1
46+
rev: v3.15.0
4747
hooks:
4848
- id: pyupgrade
4949
args: [--py38-plus] # sync with requires-python
@@ -68,7 +68,7 @@ repos:
6868
^docs/source/conf.py$
6969
)
7070
- repo: https://github.com/codespell-project/codespell
71-
rev: v2.2.5
71+
rev: v2.2.6
7272
hooks:
7373
- id: codespell
7474
additional_dependencies: [".[toml]"]

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717

1818
### Changed
1919

20-
-
20+
- Set minimal C++ standard to C++17 by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
2121

2222
### Fixed
2323

24-
-
24+
- Fix `optree` compatibility for multi-tree-map with `None` values by [@XuehaiPan](https://github.com/XuehaiPan) in [#195](https://github.com/metaopt/torchopt/pull/195).
2525

2626
### Removed
2727

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ cmake_minimum_required(VERSION 3.11) # for FetchContent
1717
project(torchopt LANGUAGES CXX)
1818

1919
include(FetchContent)
20-
set(PYBIND11_VERSION v2.10.3)
20+
set(PYBIND11_VERSION v2.11.1)
2121

2222
if(NOT CMAKE_BUILD_TYPE)
2323
set(CMAKE_BUILD_TYPE Release)
2424
endif()
2525

26-
set(CMAKE_CXX_STANDARD 14)
26+
set(CMAKE_CXX_STANDARD 17)
2727
set(CMAKE_CXX_STANDARD_REQUIRED ON)
2828

2929
find_package(Threads REQUIRED) # -pthread

conda-recipe.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ dependencies:
7777
- hunspell-en
7878
- myst-nb
7979
- ipykernel
80-
- pandoc
8180
- docutils
8281

8382
# Testing

docs/conda-recipe.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,5 +67,4 @@ dependencies:
6767
- hunspell-en
6868
- myst-nb
6969
- ipykernel
70-
- pandoc
7170
- docutils

docs/requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ torch >= 1.13
44

55
--requirement ../requirements.txt
66

7-
sphinx >= 5.2.1
7+
sphinx >= 5.2.1, < 7.0.0a0
8+
sphinxcontrib-bibtex >= 2.4
9+
sphinx-autodoc-typehints >= 1.20
10+
myst-nb >= 0.15
11+
812
sphinx-autoapi
913
sphinx-autobuild
1014
sphinx-copybutton
1115
sphinx-rtd-theme
1216
sphinxcontrib-katex
13-
sphinxcontrib-bibtex
14-
sphinx-autodoc-typehints >= 1.19.2
1517
IPython
1618
ipykernel
17-
pandoc
18-
myst-nb
1919
docutils
2020
matplotlib

torchopt/alias/sgd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
__all__ = ['sgd']
4545

4646

47+
# pylint: disable-next=too-many-arguments
4748
def sgd(
4849
lr: ScalarOrSchedule,
4950
momentum: float = 0.0,

torchopt/alias/utils.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,21 @@ def update_fn(
108108

109109
if inplace:
110110

111-
def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
111+
def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
112+
if g is None:
113+
return g
112114
if g.requires_grad:
113115
return g.add_(p, alpha=weight_decay)
114116
return g.add_(p.data, alpha=weight_decay)
115117

116-
updates = tree_map_(f, updates, params)
118+
tree_map_(f, params, updates)
117119

118120
else:
119121

120-
def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
121-
return g.add(p, alpha=weight_decay)
122+
def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
123+
return g.add(p, alpha=weight_decay) if g is not None else g
122124

123-
updates = tree_map(f, updates, params)
125+
updates = tree_map(f, params, updates)
124126

125127
return updates, state
126128

@@ -139,7 +141,7 @@ def update_fn(
139141
def f(g: torch.Tensor) -> torch.Tensor:
140142
return g.neg_()
141143

142-
updates = tree_map_(f, updates)
144+
tree_map_(f, updates)
143145

144146
else:
145147

@@ -166,19 +168,21 @@ def update_fn(
166168

167169
if inplace:
168170

169-
def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
171+
def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
172+
if g is None:
173+
return g
170174
if g.requires_grad:
171175
return g.neg_().add_(p, alpha=weight_decay)
172176
return g.neg_().add_(p.data, alpha=weight_decay)
173177

174-
updates = tree_map_(f, updates, params)
178+
tree_map_(f, params, updates)
175179

176180
else:
177181

178-
def f(g: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
179-
return g.neg().add_(p, alpha=weight_decay)
182+
def f(p: torch.Tensor, g: torch.Tensor | None) -> torch.Tensor | None:
183+
return g.neg().add_(p, alpha=weight_decay) if g is not None else g
180184

181-
updates = tree_map(f, updates, params)
185+
updates = tree_map(f, params, updates)
182186

183187
return updates, state
184188

torchopt/distributed/api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def sum_reducer(results: Iterable[torch.Tensor]) -> torch.Tensor:
271271
return torch.sum(torch.stack(tuple(results), dim=0), dim=0)
272272

273273

274+
# pylint: disable-next=too-many-arguments
274275
def remote_async_call(
275276
func: Callable[..., T],
276277
*,
@@ -328,6 +329,7 @@ def remote_async_call(
328329
return future
329330

330331

332+
# pylint: disable-next=too-many-arguments
331333
def remote_sync_call(
332334
func: Callable[..., T],
333335
*,

0 commit comments

Comments
 (0)