@@ -65,7 +65,9 @@ def expand_dims(
6565 a : array
6666 axis : int or tuple of ints
6767 Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68- If multiple positions are provided, they should be unique.
68+ If multiple positions are provided, they should be unique (note that a position
69+ given by a positive index could also be referred to by a negative index -
70+ that will also result in an error).
6971 Default: ``(0,)``.
7072 xp : array_namespace
7173 The standard-compatible namespace for `a`.
@@ -114,16 +116,19 @@ def expand_dims(
114116 """
115117 if not isinstance (axis , tuple ):
116118 axis = (axis ,)
117- if len (set (axis )) != len (axis ):
118- err_msg = "Duplicate dimensions specified in `axis`."
119- raise ValueError (err_msg )
120119 ndim = a .ndim + len (axis )
121120 if axis != () and (min (axis ) < - ndim or max (axis ) >= ndim ):
122121 err_msg = (
123122 f"a provided axis position is out of bounds for array of dimension { a .ndim } "
124123 )
125124 raise IndexError (err_msg )
126125 axis = tuple (dim % ndim for dim in axis )
126+ if len (set (axis )) != len (axis ):
127+ err_msg = "Duplicate dimensions specified in `axis`."
128+ raise ValueError (err_msg )
129+ if len (set (axis )) != len (axis ):
130+ err_msg = "Duplicate dimensions specified in `axis`."
131+ raise ValueError (err_msg )
127132 for i in sorted (axis ):
128133 a = xp .expand_dims (a , axis = i )
129134 return a
0 commit comments