@@ -198,13 +198,14 @@ def _define_dim_flags(x, axis):
198198 """
199199 Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
200200 x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
201- except for one of them ), for instance, if x.shape = (1, 1, 1, 2),
202- then x_is_1D = True
201+ except for dimension at `axis` ), for instance, if x.shape = (1, 1, 1, 2),
202+ and axis=-1, then x_is_1D = True.
203203 x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one
204204 except for the last two of them), for instance, if x.shape = (1, 1, 3, 2),
205- then x_is_2D = True
205+ then x_is_2D = True.
206206 x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance,
207- if x.shape = (3, 4, 1, 2), then x_base_is_1D = True
207+ if x.shape = (3, 4, 1, 2), then x_base_is_1D = True.
208+
208209 """
209210
210211 x_shape = x .shape
@@ -331,11 +332,11 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):
331332
332333 if x1_ndim == 1 and x2_ndim == 1 :
333334 result_shape = ()
334- elif x1_is_1D :
335+ elif x1_is_1D and not x2_is_1D :
335336 result_shape = x2_shape [:- 1 ]
336- elif x2_is_1D :
337+ elif x2_is_1D and not x1_is_1D :
337338 result_shape = x1_shape [:- 1 ]
338- else : # at least 2D
339+ else :
339340 if x1_ndim != x2_ndim :
340341 diff = abs (x1_ndim - x2_ndim )
341342 if x1_ndim < x2_ndim :
0 commit comments