|
11 | 11 | from xarray import DataArray |
12 | 12 |
|
13 | 13 | import pytensor.scalar as ps |
| 14 | +import pytensor.xtensor as px |
14 | 15 | import pytensor.xtensor.math as pxm |
15 | 16 | from pytensor import function |
16 | 17 | from pytensor.scalar import ScalarOp |
@@ -324,100 +325,139 @@ def test_full_like(): |
324 | 325 | x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
325 | 326 | x_test = xr_arange_like(x) |
326 | 327 |
|
327 | | - y1 = pxm.full_like(x, 5.0) |
| 328 | + y1 = px.full_like(x, 5.0) |
328 | 329 | fn1 = xr_function([x], y1) |
329 | 330 | result1 = fn1(x_test) |
330 | 331 | expected1 = xr.full_like(x_test, 5.0) |
331 | | - xr_assert_allclose(result1, expected1) |
| 332 | + xr_assert_allclose(result1, expected1, check_dtype=True) |
| 333 | + |
| 334 | + # Other dtypes |
| 335 | + x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32") |
| 336 | + x_3d_test = xr_arange_like(x_3d) |
332 | 337 |
|
333 | | - # Different dtypes |
334 | | - y3 = pxm.full_like(x, 5.0, dtype="int32") |
| 338 | + y7 = px.full_like(x_3d, -1.0) |
| 339 | + fn7 = xr_function([x_3d], y7) |
| 340 | + result7 = fn7(x_3d_test) |
| 341 | + expected7 = xr.full_like(x_3d_test, -1.0) |
| 342 | + xr_assert_allclose(result7, expected7, check_dtype=True) |
| 343 | + |
| 344 | + # Integer dtype |
| 345 | + y3 = px.full_like(x, 5.0, dtype="int32") |
335 | 346 | fn3 = xr_function([x], y3) |
336 | 347 | result3 = fn3(x_test) |
337 | 348 | expected3 = xr.full_like(x_test, 5.0, dtype="int32") |
338 | | - xr_assert_allclose(result3, expected3) |
| 349 | + xr_assert_allclose(result3, expected3, check_dtype=True) |
339 | 350 |
|
340 | 351 | # Different fill_value types |
341 | | - y4 = pxm.full_like(x, np.array(3.14)) |
| 352 | + y4 = px.full_like(x, np.array(3.14)) |
342 | 353 | fn4 = xr_function([x], y4) |
343 | 354 | result4 = fn4(x_test) |
344 | 355 | expected4 = xr.full_like(x_test, 3.14) |
345 | | - xr_assert_allclose(result4, expected4) |
| 356 | + xr_assert_allclose(result4, expected4, check_dtype=True) |
346 | 357 |
|
347 | 358 | # Integer input with float fill_value |
348 | 359 | x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32") |
349 | 360 | x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b")) |
350 | 361 |
|
351 | | - y5 = pxm.full_like(x_int, 2.5) |
| 362 | + y5 = px.full_like(x_int, 2.5) |
352 | 363 | fn5 = xr_function([x_int], y5) |
353 | 364 | result5 = fn5(x_int_test) |
354 | 365 | expected5 = xr.full_like(x_int_test, 2.5) |
355 | | - xr_assert_allclose(result5, expected5) |
| 366 | + xr_assert_allclose(result5, expected5, check_dtype=True) |
356 | 367 |
|
357 | 368 | # Symbolic shapes |
358 | 369 | x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3)) |
359 | | - x_sym_test = DataArray(np.arange(6).reshape(2, 3), dims=("a", "b")) |
| 370 | + x_sym_test = DataArray( |
| 371 | + np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b") |
| 372 | + ) |
360 | 373 |
|
361 | | - y6 = pxm.full_like(x_sym, 7.0) |
| 374 | + y6 = px.full_like(x_sym, 7.0) |
362 | 375 | fn6 = xr_function([x_sym], y6) |
363 | 376 | result6 = fn6(x_sym_test) |
364 | 377 | expected6 = xr.full_like(x_sym_test, 7.0) |
365 | | - xr_assert_allclose(result6, expected6) |
366 | | - |
367 | | - # Higher dimensional tensor |
368 | | - x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32") |
369 | | - x_3d_test = xr_arange_like(x_3d) |
370 | | - |
371 | | - y7 = pxm.full_like(x_3d, -1.0) |
372 | | - fn7 = xr_function([x_3d], y7) |
373 | | - result7 = fn7(x_3d_test) |
374 | | - expected7 = xr.full_like(x_3d_test, -1.0) |
375 | | - xr_assert_allclose(result7, expected7) |
| 378 | + xr_assert_allclose(result6, expected6, check_dtype=True) |
376 | 379 |
|
377 | 380 | # Boolean dtype |
378 | 381 | x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool") |
379 | 382 | x_bool_test = DataArray( |
380 | 383 | np.array([[True, False, True], [False, True, False]]), dims=("a", "b") |
381 | 384 | ) |
382 | 385 |
|
383 | | - y8 = pxm.full_like(x_bool, True) |
| 386 | + y8 = px.full_like(x_bool, True) |
384 | 387 | fn8 = xr_function([x_bool], y8) |
385 | 388 | result8 = fn8(x_bool_test) |
386 | 389 | expected8 = xr.full_like(x_bool_test, True) |
387 | | - xr_assert_allclose(result8, expected8) |
| 390 | + xr_assert_allclose(result8, expected8, check_dtype=True) |
388 | 391 |
|
389 | 392 | # Complex dtype |
390 | 393 | x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64") |
391 | 394 | x_complex_test = DataArray( |
392 | 395 | np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b") |
393 | 396 | ) |
394 | 397 |
|
395 | | - y9 = pxm.full_like(x_complex, 1 + 2j) |
| 398 | + y9 = px.full_like(x_complex, 1 + 2j) |
396 | 399 | fn9 = xr_function([x_complex], y9) |
397 | 400 | result9 = fn9(x_complex_test) |
398 | 401 | expected9 = xr.full_like(x_complex_test, 1 + 2j) |
399 | | - xr_assert_allclose(result9, expected9) |
| 402 | + xr_assert_allclose(result9, expected9, check_dtype=True) |
| 403 | + |
| 404 | + # Symbolic fill value |
| 405 | + x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 406 | + fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64") |
| 407 | + x_sym_fill_test = xr_arange_like(x_sym_fill) |
| 408 | + fill_val_test = DataArray(3.14, dims=()) |
| 409 | + |
| 410 | + y10 = px.full_like(x_sym_fill, fill_val) |
| 411 | + fn10 = xr_function([x_sym_fill, fill_val], y10) |
| 412 | + result10 = fn10(x_sym_fill_test, fill_val_test) |
| 413 | + expected10 = xr.full_like(x_sym_fill_test, 3.14) |
| 414 | + xr_assert_allclose(result10, expected10, check_dtype=True) |
| 415 | + |
| 416 | + # Test dtype conversion to bool when neither input nor fill_value are bool |
| 417 | + x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 418 | + x_float_test = xr_arange_like(x_float) |
| 419 | + |
| 420 | + y11 = px.full_like(x_float, 5.0, dtype="bool") |
| 421 | + fn11 = xr_function([x_float], y11) |
| 422 | + result11 = fn11(x_float_test) |
| 423 | + expected11 = xr.full_like(x_float_test, 5.0, dtype="bool") |
| 424 | + xr_assert_allclose(result11, expected11, check_dtype=True) |
| 425 | + |
| 426 | + # Verify the result is actually boolean |
| 427 | + assert result11.dtype == "bool" |
| 428 | + assert expected11.dtype == "bool" |
| 429 | + |
| 430 | + |
| 431 | +def test_full_like_errors(): |
| 432 | + """Test full_like function errors.""" |
| 433 | + x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
| 434 | + x_test = xr_arange_like(x) |
| 435 | + |
| 436 | + with pytest.raises(ValueError, match="fill_value must be a scalar"): |
| 437 | + px.full_like(x, x_test) |
400 | 438 |
|
401 | 439 |
|
402 | 440 | def test_ones_like(): |
403 | 441 | """Test ones_like function, comparing with xarray's ones_like.""" |
404 | 442 | x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
405 | 443 | x_test = xr_arange_like(x) |
406 | 444 |
|
407 | | - y1 = pxm.ones_like(x) |
| 445 | + y1 = px.ones_like(x) |
408 | 446 | fn1 = xr_function([x], y1) |
409 | 447 | result1 = fn1(x_test) |
410 | 448 | expected1 = xr.ones_like(x_test) |
411 | 449 | xr_assert_allclose(result1, expected1) |
| 450 | + assert result1.dtype == expected1.dtype |
412 | 451 |
|
413 | 452 |
|
414 | 453 | def test_zeros_like(): |
415 | 454 | """Test zeros_like function, comparing with xarray's zeros_like.""" |
416 | 455 | x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64") |
417 | 456 | x_test = xr_arange_like(x) |
418 | 457 |
|
419 | | - y1 = pxm.zeros_like(x) |
| 458 | + y1 = px.zeros_like(x) |
420 | 459 | fn1 = xr_function([x], y1) |
421 | 460 | result1 = fn1(x_test) |
422 | 461 | expected1 = xr.zeros_like(x_test) |
423 | 462 | xr_assert_allclose(result1, expected1) |
| 463 | + assert result1.dtype == expected1.dtype |
0 commit comments