|
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | import pytest |
12 | | -import xarray as xr |
13 | 12 | from xarray import DataArray |
14 | 13 | from xarray import concat as xr_concat |
15 | 14 |
|
16 | | -from pytensor.tensor import scalar |
17 | 15 | from pytensor.xtensor.shape import ( |
18 | 16 | concat, |
19 | | - expand_dims, |
20 | 17 | squeeze, |
21 | 18 | stack, |
22 | 19 | transpose, |
@@ -269,156 +266,6 @@ def test_concat_scalar(): |
269 | 266 | xr_assert_allclose(res, expected_res) |
270 | 267 |
|
271 | 268 |
|
272 | | -def test_expand_dims_explicit(): |
273 | | - """Test expand_dims with explicitly named dimensions and sizes.""" |
274 | | - |
275 | | - # 1D case |
276 | | - x = xtensor("x", dims=("city",), shape=(3,)) |
277 | | - y = expand_dims(x, "country") |
278 | | - fn = xr_function([x], y) |
279 | | - x_xr = xr_arange_like(x) |
280 | | - xr_assert_allclose(fn(x_xr), x_xr.expand_dims("country")) |
281 | | - |
282 | | - # 2D case |
283 | | - x = xtensor("x", dims=("city", "year"), shape=(2, 2)) |
284 | | - y = expand_dims(x, "country") |
285 | | - fn = xr_function([x], y) |
286 | | - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) |
287 | | - |
288 | | - # 3D case |
289 | | - x = xtensor("x", dims=("city", "year", "month"), shape=(2, 2, 2)) |
290 | | - y = expand_dims(x, "country") |
291 | | - fn = xr_function([x], y) |
292 | | - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("country")) |
293 | | - |
294 | | - # Prepending various dims |
295 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
296 | | - for new_dim in ("x", "y", "z"): |
297 | | - y = expand_dims(x, new_dim) |
298 | | - assert y.type.dims == (new_dim, "a", "b") |
299 | | - assert y.type.shape == (1, 2, 3) |
300 | | - |
301 | | - # Explicit size=1 behaves like default |
302 | | - y1 = expand_dims(x, "batch", size=1) |
303 | | - y2 = expand_dims(x, "batch") |
304 | | - fn1 = xr_function([x], y1) |
305 | | - fn2 = xr_function([x], y2) |
306 | | - x_test = xr_arange_like(x) |
307 | | - xr_assert_allclose(fn1(x_test), fn2(x_test)) |
308 | | - |
309 | | - # Scalar expansion |
310 | | - x = xtensor("x", dims=(), shape=()) |
311 | | - y = expand_dims(x, "batch") |
312 | | - assert y.type.dims == ("batch",) |
313 | | - assert y.type.shape == (1,) |
314 | | - fn = xr_function([x], y) |
315 | | - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x).expand_dims("batch")) |
316 | | - |
317 | | - # Static size > 1: broadcast |
318 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
319 | | - y = expand_dims(x, "batch", size=4) |
320 | | - fn = xr_function([x], y) |
321 | | - expected = xr.DataArray( |
322 | | - np.broadcast_to(xr_arange_like(x).data, (4, 2, 3)), |
323 | | - dims=("batch", "a", "b"), |
324 | | - coords={"a": xr_arange_like(x).coords["a"], "b": xr_arange_like(x).coords["b"]}, |
325 | | - ) |
326 | | - xr_assert_allclose(fn(xr_arange_like(x)), expected) |
327 | | - |
328 | | - # Insert new dim between existing dims |
329 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
330 | | - y = expand_dims(x, "new") |
331 | | - # Insert new dim between a and b: ("a", "new", "b") |
332 | | - y = transpose(y, "a", "new", "b") |
333 | | - fn = xr_function([x], y) |
334 | | - x_test = xr_arange_like(x) |
335 | | - expected = x_test.expand_dims("new").transpose("a", "new", "b") |
336 | | - xr_assert_allclose(fn(x_test), expected) |
337 | | - |
338 | | - # Expand with multiple dims |
339 | | - x = xtensor("x", dims=(), shape=()) |
340 | | - y = expand_dims(expand_dims(x, "a"), "b") |
341 | | - fn = xr_function([x], y) |
342 | | - expected = xr_arange_like(x).expand_dims("a").expand_dims("b") |
343 | | - xr_assert_allclose(fn(xr_arange_like(x)), expected) |
344 | | - |
345 | | - |
346 | | -def test_expand_dims_implicit(): |
347 | | - """Test expand_dims with default or symbolic sizes and dim=None.""" |
348 | | - |
349 | | - # Symbolic size=1: same as default |
350 | | - size_sym_1 = scalar("size_sym_1", dtype="int64") |
351 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
352 | | - y = expand_dims(x, "batch", size=size_sym_1) |
353 | | - fn = xr_function([x, size_sym_1], y, on_unused_input="ignore") |
354 | | - expected = xr_arange_like(x).expand_dims("batch") |
355 | | - xr_assert_allclose(fn(xr_arange_like(x), 1), expected) |
356 | | - |
357 | | - # Symbolic size > 1 (but expand only adds dim=1) |
358 | | - size_sym_4 = scalar("size_sym_4", dtype="int64") |
359 | | - y = expand_dims(x, "batch", size=size_sym_4) |
360 | | - fn = xr_function([x, size_sym_4], y, on_unused_input="ignore") |
361 | | - xr_assert_allclose(fn(xr_arange_like(x), 4), expected) |
362 | | - |
363 | | - # Reversibility: expand then squeeze |
364 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
365 | | - y = expand_dims(x, "batch") |
366 | | - z = squeeze(y, "batch") |
367 | | - fn = xr_function([x], z) |
368 | | - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) |
369 | | - |
370 | | - # expand_dims with dim=None = no-op |
371 | | - x = xtensor("x", dims=("a",), shape=(3,)) |
372 | | - y = expand_dims(x, None) |
373 | | - fn = xr_function([x], y) |
374 | | - xr_assert_allclose(fn(xr_arange_like(x)), xr_arange_like(x)) |
375 | | - |
376 | | - # broadcast after symbolic size |
377 | | - size_sym = scalar("size_sym", dtype="int64") |
378 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3)) |
379 | | - y = expand_dims(x, "batch", size=size_sym) |
380 | | - z = y + y # triggers shape alignment |
381 | | - fn = xr_function([x, size_sym], z, on_unused_input="ignore") |
382 | | - x_test = xr_arange_like(x) |
383 | | - out = fn(x_test, 1) |
384 | | - expected = x_test.expand_dims("batch") + x_test.expand_dims("batch") |
385 | | - xr_assert_allclose(out, expected) |
386 | | - |
387 | | - |
388 | | -def test_expand_dims_errors(): |
389 | | - """Test error handling in expand_dims.""" |
390 | | - |
391 | | - # Expanding existing dim |
392 | | - x = xtensor("x", dims=("city",), shape=(3,)) |
393 | | - y = expand_dims(x, "country") |
394 | | - with pytest.raises(ValueError, match="already exists"): |
395 | | - expand_dims(y, "city") |
396 | | - |
397 | | - # Size = 0 is invalid |
398 | | - with pytest.raises(ValueError, match="size must be.*positive"): |
399 | | - expand_dims(x, "batch", size=0) |
400 | | - |
401 | | - # Invalid dim type |
402 | | - with pytest.raises(TypeError): |
403 | | - expand_dims(x, 123) |
404 | | - |
405 | | - # Invalid size type |
406 | | - with pytest.raises(TypeError): |
407 | | - expand_dims(x, "new", size=[1]) |
408 | | - |
409 | | - # Duplicate dimension creation |
410 | | - y = expand_dims(x, "new") |
411 | | - with pytest.raises(ValueError): |
412 | | - expand_dims(y, "new") |
413 | | - |
414 | | - # Symbolic size with invalid runtime value |
415 | | - size_sym = scalar("size_sym", dtype="int64") |
416 | | - y = expand_dims(x, "batch", size=size_sym) |
417 | | - fn = xr_function([x, size_sym], y, on_unused_input="ignore") |
418 | | - with pytest.raises(Exception): |
419 | | - fn(xr_arange_like(x), 0) |
420 | | - |
421 | | - |
422 | 269 | def test_squeeze_explicit_dims(): |
423 | 270 | """Test squeeze with explicit dimension(s).""" |
424 | 271 |
|
@@ -487,13 +334,13 @@ def test_squeeze_implicit_dims(): |
487 | 334 | fn4 = xr_function([x4], y4) |
488 | 335 | xr_assert_allclose(fn4(x4_test), x4_test.squeeze("b")) |
489 | 336 |
|
490 | | - # Reversibility with expand_dims |
491 | | - x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3)) |
492 | | - y5 = squeeze(x5, "time") |
493 | | - z5 = expand_dims(y5, "time") |
494 | | - fn5 = xr_function([x5], z5) |
495 | | - x5_test = xr_arange_like(x5) |
496 | | - xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) |
| 337 | + # Reversibility with expand_dims (restore when expand_dims is implemented) |
| 338 | + # x5 = xtensor("x5", dims=("batch", "time", "feature"), shape=(2, 1, 3)) |
| 339 | + # y5 = squeeze(x5, "time") |
| 340 | + # z5 = expand_dims(y5, "time") |
| 341 | + # fn5 = xr_function([x5], z5) |
| 342 | + # x5_test = xr_arange_like(x5) |
| 343 | + # xr_assert_allclose(fn5(x5_test).transpose(*x5_test.dims), x5_test) |
497 | 344 |
|
498 | 345 | """ |
499 | 346 | This test documents that we intentionally don't squeeze dimensions with symbolic shapes |
|
0 commit comments