Skip to content

Commit 088dfb2

Browse files
authored
Make OrderedSet public (#356)
* Import OrderedSet in _transform/__init__.py * Import OrderedSet directly from _transform rather than from _transform._ordered_set
1 parent cd606bc commit 088dfb2

File tree

11 files changed

+20
-14
lines changed

11 files changed

+20
-14
lines changed

src/torchjd/_autojac/_backward.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@
44

55
from torchjd.aggregation import Aggregator
66

7-
from ._transform import Accumulate, Aggregate, Diagonalize, EmptyTensorDict, Init, Jac, Transform
8-
from ._transform._ordered_set import OrderedSet
7+
from ._transform import (
8+
Accumulate,
9+
Aggregate,
10+
Diagonalize,
11+
EmptyTensorDict,
12+
Init,
13+
Jac,
14+
OrderedSet,
15+
Transform,
16+
)
917
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
1018

1119

src/torchjd/_autojac/_mtl_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
Gradients,
1313
Init,
1414
Jac,
15+
OrderedSet,
1516
Select,
1617
Stack,
1718
Transform,
1819
)
19-
from ._transform._ordered_set import OrderedSet
2020
from ._utils import as_checked_ordered_set, check_optional_positive_chunk_size, get_leaf_tensors
2121

2222

src/torchjd/_autojac/_transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from ._grad import Grad
66
from ._init import Init
77
from ._jac import Jac
8+
from ._ordered_set import OrderedSet
89
from ._select import Select
910
from ._stack import Stack
1011
from ._tensor_dict import (

src/torchjd/_autojac/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch import Tensor
55
from torch.autograd.graph import Node
66

7-
from ._transform._ordered_set import OrderedSet
7+
from ._transform import OrderedSet
88

99

1010
def check_optional_positive_chunk_size(parallel_chunk_size: int | None) -> None:

tests/unit/autojac/_transform/test_aggregate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
GradientVectors,
1111
JacobianMatrices,
1212
Jacobians,
13+
OrderedSet,
1314
RequirementError,
1415
)
1516
from torchjd._autojac._transform._aggregate import _AggregateMatrices, _Matrixify, _Reshape
16-
from torchjd._autojac._transform._ordered_set import OrderedSet
1717
from torchjd.aggregation import Random
1818

1919
from ._dict_assertions import assert_tensor_dicts_are_close

tests/unit/autojac/_transform/test_diagonalize.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22
from pytest import raises
33

4-
from torchjd._autojac._transform import Diagonalize, Gradients, RequirementError
5-
from torchjd._autojac._transform._ordered_set import OrderedSet
4+
from torchjd._autojac._transform import Diagonalize, Gradients, OrderedSet, RequirementError
65

76
from ._dict_assertions import assert_tensor_dicts_are_close
87

tests/unit/autojac/_transform/test_grad.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22
from pytest import raises
33

4-
from torchjd._autojac._transform import Grad, Gradients, RequirementError
5-
from torchjd._autojac._transform._ordered_set import OrderedSet
4+
from torchjd._autojac._transform import Grad, Gradients, OrderedSet, RequirementError
65

76
from ._dict_assertions import assert_tensor_dicts_are_close
87

tests/unit/autojac/_transform/test_interactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
Init,
1313
Jac,
1414
Jacobians,
15+
OrderedSet,
1516
RequirementError,
1617
Select,
1718
Stack,
1819
TensorDict,
1920
)
20-
from torchjd._autojac._transform._ordered_set import OrderedSet
2121

2222
from ._dict_assertions import assert_tensor_dicts_are_close
2323

tests/unit/autojac/_transform/test_jac.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22
from pytest import mark, raises
33

4-
from torchjd._autojac._transform import Jac, Jacobians, RequirementError
5-
from torchjd._autojac._transform._ordered_set import OrderedSet
4+
from torchjd._autojac._transform import Jac, Jacobians, OrderedSet, RequirementError
65

76
from ._dict_assertions import assert_tensor_dicts_are_close
87

tests/unit/autojac/test_backward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from torchjd import backward
77
from torchjd._autojac._backward import _create_transform
8-
from torchjd._autojac._transform._ordered_set import OrderedSet
8+
from torchjd._autojac._transform import OrderedSet
99
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, Sum, UPGrad
1010

1111

0 commit comments

Comments
 (0)