|
7 | 7 | import inspect |
8 | 8 |
|
9 | 9 | import numpy as np |
10 | | -import xarray as xr |
11 | 10 | from xarray import DataArray |
12 | 11 |
|
13 | 12 | import pytensor.scalar as ps |
14 | | -import pytensor.xtensor as px |
15 | 13 | import pytensor.xtensor.math as pxm |
16 | 14 | from pytensor import function |
17 | 15 | from pytensor.scalar import ScalarOp |
@@ -316,148 +314,3 @@ def test_dot_errors(): |
316 | 314 | # Doesn't fail until the rewrite |
317 | 315 | with pytest.raises(ValueError, match="not aligned"): |
318 | 316 | fn(x_test, y_test) |
319 | | - |
320 | | - |
321 | | -def test_full_like(): |
322 | | - """Test full_like function, comparing with xarray's full_like.""" |
323 | | - |
324 | | - # Basic functionality with scalar fill_value |
325 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
326 | | - x_test = xr_arange_like(x) |
327 | | - |
328 | | - y1 = px.full_like(x, 5.0) |
329 | | - fn1 = xr_function([x], y1) |
330 | | - result1 = fn1(x_test) |
331 | | - expected1 = xr.full_like(x_test, 5.0) |
332 | | - xr_assert_allclose(result1, expected1, check_dtype=True) |
333 | | - |
334 | | - # Other dtypes |
335 | | - x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32") |
336 | | - x_3d_test = xr_arange_like(x_3d) |
337 | | - |
338 | | - y7 = px.full_like(x_3d, -1.0) |
339 | | - fn7 = xr_function([x_3d], y7) |
340 | | - result7 = fn7(x_3d_test) |
341 | | - expected7 = xr.full_like(x_3d_test, -1.0) |
342 | | - xr_assert_allclose(result7, expected7, check_dtype=True) |
343 | | - |
344 | | - # Integer dtype |
345 | | - y3 = px.full_like(x, 5.0, dtype="int32") |
346 | | - fn3 = xr_function([x], y3) |
347 | | - result3 = fn3(x_test) |
348 | | - expected3 = xr.full_like(x_test, 5.0, dtype="int32") |
349 | | - xr_assert_allclose(result3, expected3, check_dtype=True) |
350 | | - |
351 | | - # Different fill_value types |
352 | | - y4 = px.full_like(x, np.array(3.14)) |
353 | | - fn4 = xr_function([x], y4) |
354 | | - result4 = fn4(x_test) |
355 | | - expected4 = xr.full_like(x_test, 3.14) |
356 | | - xr_assert_allclose(result4, expected4, check_dtype=True) |
357 | | - |
358 | | - # Integer input with float fill_value |
359 | | - x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32") |
360 | | - x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b")) |
361 | | - |
362 | | - y5 = px.full_like(x_int, 2.5) |
363 | | - fn5 = xr_function([x_int], y5) |
364 | | - result5 = fn5(x_int_test) |
365 | | - expected5 = xr.full_like(x_int_test, 2.5) |
366 | | - xr_assert_allclose(result5, expected5, check_dtype=True) |
367 | | - |
368 | | - # Symbolic shapes |
369 | | - x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3)) |
370 | | - x_sym_test = DataArray( |
371 | | - np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b") |
372 | | - ) |
373 | | - |
374 | | - y6 = px.full_like(x_sym, 7.0) |
375 | | - fn6 = xr_function([x_sym], y6) |
376 | | - result6 = fn6(x_sym_test) |
377 | | - expected6 = xr.full_like(x_sym_test, 7.0) |
378 | | - xr_assert_allclose(result6, expected6, check_dtype=True) |
379 | | - |
380 | | - # Boolean dtype |
381 | | - x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool") |
382 | | - x_bool_test = DataArray( |
383 | | - np.array([[True, False, True], [False, True, False]]), dims=("a", "b") |
384 | | - ) |
385 | | - |
386 | | - y8 = px.full_like(x_bool, True) |
387 | | - fn8 = xr_function([x_bool], y8) |
388 | | - result8 = fn8(x_bool_test) |
389 | | - expected8 = xr.full_like(x_bool_test, True) |
390 | | - xr_assert_allclose(result8, expected8, check_dtype=True) |
391 | | - |
392 | | - # Complex dtype |
393 | | - x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64") |
394 | | - x_complex_test = DataArray( |
395 | | - np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b") |
396 | | - ) |
397 | | - |
398 | | - y9 = px.full_like(x_complex, 1 + 2j) |
399 | | - fn9 = xr_function([x_complex], y9) |
400 | | - result9 = fn9(x_complex_test) |
401 | | - expected9 = xr.full_like(x_complex_test, 1 + 2j) |
402 | | - xr_assert_allclose(result9, expected9, check_dtype=True) |
403 | | - |
404 | | - # Symbolic fill value |
405 | | - x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64") |
406 | | - fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64") |
407 | | - x_sym_fill_test = xr_arange_like(x_sym_fill) |
408 | | - fill_val_test = DataArray(3.14, dims=()) |
409 | | - |
410 | | - y10 = px.full_like(x_sym_fill, fill_val) |
411 | | - fn10 = xr_function([x_sym_fill, fill_val], y10) |
412 | | - result10 = fn10(x_sym_fill_test, fill_val_test) |
413 | | - expected10 = xr.full_like(x_sym_fill_test, 3.14) |
414 | | - xr_assert_allclose(result10, expected10, check_dtype=True) |
415 | | - |
416 | | - # Test dtype conversion to bool when neither input nor fill_value are bool |
417 | | - x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64") |
418 | | - x_float_test = xr_arange_like(x_float) |
419 | | - |
420 | | - y11 = px.full_like(x_float, 5.0, dtype="bool") |
421 | | - fn11 = xr_function([x_float], y11) |
422 | | - result11 = fn11(x_float_test) |
423 | | - expected11 = xr.full_like(x_float_test, 5.0, dtype="bool") |
424 | | - xr_assert_allclose(result11, expected11, check_dtype=True) |
425 | | - |
426 | | - # Verify the result is actually boolean |
427 | | - assert result11.dtype == "bool" |
428 | | - assert expected11.dtype == "bool" |
429 | | - |
430 | | - |
431 | | -def test_full_like_errors(): |
432 | | - """Test full_like function errors.""" |
433 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
434 | | - x_test = xr_arange_like(x) |
435 | | - |
436 | | - with pytest.raises(ValueError, match="fill_value must be a scalar"): |
437 | | - px.full_like(x, x_test) |
438 | | - |
439 | | - |
440 | | -def test_ones_like(): |
441 | | - """Test ones_like function, comparing with xarray's ones_like.""" |
442 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
443 | | - x_test = xr_arange_like(x) |
444 | | - |
445 | | - y1 = px.ones_like(x) |
446 | | - fn1 = xr_function([x], y1) |
447 | | - result1 = fn1(x_test) |
448 | | - expected1 = xr.ones_like(x_test) |
449 | | - xr_assert_allclose(result1, expected1) |
450 | | - assert result1.dtype == expected1.dtype |
451 | | - |
452 | | - |
453 | | -def test_zeros_like(): |
454 | | - """Test zeros_like function, comparing with xarray's zeros_like.""" |
455 | | - x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
456 | | - x_test = xr_arange_like(x) |
457 | | - |
458 | | - y1 = px.zeros_like(x) |
459 | | - fn1 = xr_function([x], y1) |
460 | | - result1 = fn1(x_test) |
461 | | - expected1 = xr.zeros_like(x_test) |
462 | | - xr_assert_allclose(result1, expected1) |
463 | | - assert result1.dtype == expected1.dtype |
0 commit comments