@@ -116,8 +116,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
116116 if use_multicolor_lines :
117117 if color .shape != grid .shape :
118118 raise ValueError ("If 'color' is given, it must match the shape of "
119- "'Grid (x, y)' " )
120- line_colors = []
119+ "the (x, y) grid " )
120+ line_colors = [[]] # Empty entry allows concatenation of zero arrays.
121121 color = np .ma .masked_invalid (color )
122122 else :
123123 line_kw ['color' ] = color
@@ -126,7 +126,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
126126 if isinstance (linewidth , np .ndarray ):
127127 if linewidth .shape != grid .shape :
128128 raise ValueError ("If 'linewidth' is given, it must match the "
129- "shape of 'Grid (x, y)' " )
129+ "shape of the (x, y) grid " )
130130 line_kw ['linewidth' ] = []
131131 else :
132132 line_kw ['linewidth' ] = linewidth
@@ -137,7 +137,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
137137
138138 # Sanity checks.
139139 if u .shape != grid .shape or v .shape != grid .shape :
140- raise ValueError ("'u' and 'v' must match the shape of 'Grid (x, y)' " )
140+ raise ValueError ("'u' and 'v' must match the shape of the (x, y) grid " )
141141
142142 u = np .ma .masked_invalid (u )
143143 v = np .ma .masked_invalid (v )
@@ -310,21 +310,22 @@ class Grid:
310310 """Grid of data."""
311311 def __init__ (self , x , y ):
312312
313- if x .ndim == 1 :
313+ if np .ndim ( x ) == 1 :
314314 pass
315- elif x .ndim == 2 :
316- x_row = x [0 , : ]
315+ elif np .ndim ( x ) == 2 :
316+ x_row = x [0 ]
317317 if not np .allclose (x_row , x ):
318318 raise ValueError ("The rows of 'x' must be equal" )
319319 x = x_row
320320 else :
321321 raise ValueError ("'x' can have at maximum 2 dimensions" )
322322
323- if y .ndim == 1 :
323+ if np .ndim ( y ) == 1 :
324324 pass
325- elif y .ndim == 2 :
326- y_col = y [:, 0 ]
327- if not np .allclose (y_col , y .T ):
325+ elif np .ndim (y ) == 2 :
326+ yt = np .transpose (y ) # Also works for nested lists.
327+ y_col = yt [0 ]
328+ if not np .allclose (y_col , yt ):
328329 raise ValueError ("The columns of 'y' must be equal" )
329330 y = y_col
330331 else :
0 commit comments