We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9031209 commit 1dc8650Copy full SHA for 1dc8650
src/phyjax2d/impl.py
@@ -77,7 +77,7 @@ def __truediv__(self, o: float | jax.Array) -> Self:
77
return jax.tree_util.tree_map(lambda x: x / o, self)
78
79
@jax.jit
80
- def get_slice(self, index: INDEX) -> Self:
+ def get_slice(self, index: Sequence[int] | int | None) -> Self:
81
return jax.tree_util.tree_map(lambda x: x[index], self)
82
83
def split(self, split_index: int) -> tuple[Self, Self]:
@@ -900,8 +900,8 @@ def set_ignore_flags_by_indices(
900
self,
901
target_n1: str,
902
target_n2: str,
903
- n1_idx: INDEX,
904
- n2_idx: INDEX,
+ n1_idx: Sequence[int] | int | None,
+ n2_idx: Sequence[int] | int | None,
905
) -> None:
906
start = 0
907
for n1, n2 in _CONTACT_FUNCTIONS.keys():
0 commit comments