Skip to content

Commit 77e835e

Browse files
committed
Moving from math to shape
1 parent 91465b8 commit 77e835e

File tree

5 files changed

+250
-253
lines changed

5 files changed

+250
-253
lines changed

pytensor/xtensor/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
import pytensor.xtensor.rewriting
44
from pytensor.xtensor import linalg
5-
from pytensor.xtensor.math import dot, full_like, ones_like, zeros_like
6-
from pytensor.xtensor.shape import concat
5+
from pytensor.xtensor.math import dot
6+
from pytensor.xtensor.shape import concat, full_like, ones_like, zeros_like
77
from pytensor.xtensor.type import (
88
as_xtensor,
99
xtensor,

pytensor/xtensor/math.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -250,103 +250,3 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
250250
result = XDot(dims=tuple(dim_set))(x, y)
251251

252252
return result
253-
254-
255-
def full_like(x, fill_value, dtype=None):
256-
"""Create a new XTensorVariable with the same shape and dimensions, filled with a specified value.
257-
258-
Parameters
259-
----------
260-
x : XTensorVariable
261-
The tensor to fill.
262-
fill_value : scalar or XTensorVariable
263-
The value to fill the new tensor with.
264-
dtype : str or np.dtype, optional
265-
The data type of the new tensor. If None, uses the dtype of the input tensor.
266-
267-
Returns
268-
-------
269-
XTensorVariable
270-
A new tensor with the same shape and dimensions as self, filled with fill_value.
271-
272-
Examples
273-
--------
274-
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
275-
>>> y = full_like(x, 5.0)
276-
>>> y.dims
277-
('a', 'b')
278-
>>> y.shape
279-
(2, 3)
280-
"""
281-
x = as_xtensor(x)
282-
fill_value = as_xtensor(fill_value)
283-
284-
# Check that fill_value is a scalar (ndim=0)
285-
if fill_value.type.ndim != 0:
286-
raise ValueError(
287-
f"fill_value must be a scalar, got ndim={fill_value.type.ndim}"
288-
)
289-
290-
# Handle dtype conversion
291-
if dtype is not None:
292-
# If dtype is specified, cast the fill_value to that dtype
293-
fill_value = cast(fill_value, dtype)
294-
else:
295-
# If dtype is None, cast the fill_value to the input tensor's dtype
296-
# This matches xarray's behavior where it preserves the original dtype
297-
fill_value = cast(fill_value, x.type.dtype)
298-
299-
# Use the xtensor second function
300-
return second(x, fill_value)
301-
302-
303-
def ones_like(x, dtype=None):
304-
"""Create a new XTensorVariable with the same shape and dimensions, filled with ones.
305-
306-
Parameters
307-
----------
308-
x : XTensorVariable
309-
The tensor to fill.
310-
dtype : str or np.dtype, optional
311-
The data type of the new tensor. If None, uses the dtype of the input tensor.
312-
313-
Returns:
314-
XTensorVariable
315-
A new tensor with the same shape and dimensions as self, filled with ones.
316-
317-
Examples
318-
--------
319-
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
320-
>>> y = ones_like(x)
321-
>>> y.dims
322-
('a', 'b')
323-
>>> y.shape
324-
(2, 3)
325-
"""
326-
return full_like(x, 1.0, dtype=dtype)
327-
328-
329-
def zeros_like(x, dtype=None):
330-
"""Create a new XTensorVariable with the same shape and dimensions, filled with zeros.
331-
332-
Parameters
333-
----------
334-
x : XTensorVariable
335-
The tensor to fill.
336-
dtype : str or np.dtype, optional
337-
The data type of the new tensor. If None, uses the dtype of the input tensor.
338-
339-
Returns:
340-
XTensorVariable
341-
A new tensor with the same shape and dimensions as self, filled with zeros.
342-
343-
Examples
344-
--------
345-
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
346-
>>> y = zeros_like(x)
347-
>>> y.dims
348-
('a', 'b')
349-
>>> y.shape
350-
(2, 3)
351-
"""
352-
return full_like(x, 0.0, dtype=dtype)

pytensor/xtensor/shape.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pytensor.tensor.type import integer_dtypes
1414
from pytensor.tensor.utils import get_static_shape_from_size_variables
1515
from pytensor.xtensor.basic import XOp
16+
from pytensor.xtensor.math import cast, second
1617
from pytensor.xtensor.type import as_xtensor, xtensor
1718

1819

@@ -498,3 +499,100 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
498499
x = Transpose(dims=tuple(target_dims))(x)
499500

500501
return x
502+
503+
504+
def full_like(x, fill_value, dtype=None):
505+
"""Create a new XTensorVariable with the same shape and dimensions, filled with a specified value.
506+
507+
Parameters
508+
----------
509+
x : XTensorVariable
510+
The tensor to fill.
511+
fill_value : scalar or XTensorVariable
512+
The value to fill the new tensor with.
513+
dtype : str or np.dtype, optional
514+
The data type of the new tensor. If None, uses the dtype of the input tensor.
515+
516+
Returns
517+
-------
518+
XTensorVariable
519+
A new tensor with the same shape and dimensions as self, filled with fill_value.
520+
521+
Examples
522+
--------
523+
>>> from pytensor.xtensor import xtensor, full_like
524+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
525+
>>> y = full_like(x, 5.0)
526+
>>> assert y.dims == ("a", "b")
527+
>>> assert y.type.shape == (2, 3)
528+
"""
529+
x = as_xtensor(x)
530+
fill_value = as_xtensor(fill_value)
531+
532+
# Check that fill_value is a scalar (ndim=0)
533+
if fill_value.type.ndim != 0:
534+
raise ValueError(
535+
f"fill_value must be a scalar, got ndim={fill_value.type.ndim}"
536+
)
537+
538+
# Handle dtype conversion
539+
if dtype is not None:
540+
# If dtype is specified, cast the fill_value to that dtype
541+
fill_value = cast(fill_value, dtype)
542+
else:
543+
# If dtype is None, cast the fill_value to the input tensor's dtype
544+
# This matches xarray's behavior where it preserves the original dtype
545+
fill_value = cast(fill_value, x.type.dtype)
546+
547+
# Use the xtensor second function
548+
return second(x, fill_value)
549+
550+
551+
def ones_like(x, dtype=None):
552+
"""Create a new XTensorVariable with the same shape and dimensions, filled with ones.
553+
554+
Parameters
555+
----------
556+
x : XTensorVariable
557+
The tensor to fill.
558+
dtype : str or np.dtype, optional
559+
The data type of the new tensor. If None, uses the dtype of the input tensor.
560+
561+
Returns:
562+
XTensorVariable
563+
A new tensor with the same shape and dimensions as self, filled with ones.
564+
565+
Examples
566+
--------
567+
>>> from pytensor.xtensor import xtensor, full_like
568+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
569+
>>> y = ones_like(x)
570+
>>> assert y.dims == ("a", "b")
571+
>>> assert y.type.shape == (2, 3)
572+
"""
573+
return full_like(x, 1.0, dtype=dtype)
574+
575+
576+
def zeros_like(x, dtype=None):
577+
"""Create a new XTensorVariable with the same shape and dimensions, filled with zeros.
578+
579+
Parameters
580+
----------
581+
x : XTensorVariable
582+
The tensor to fill.
583+
dtype : str or np.dtype, optional
584+
The data type of the new tensor. If None, uses the dtype of the input tensor.
585+
586+
Returns:
587+
XTensorVariable
588+
A new tensor with the same shape and dimensions as self, filled with zeros.
589+
590+
Examples
591+
--------
592+
>>> from pytensor.xtensor import xtensor, full_like
593+
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
594+
>>> y = zeros_like(x)
595+
>>> assert y.dims == ("a", "b")
596+
>>> assert y.type.shape == (2, 3)
597+
"""
598+
return full_like(x, 0.0, dtype=dtype)

tests/xtensor/test_math.py

Lines changed: 0 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
import inspect
88

99
import numpy as np
10-
import xarray as xr
1110
from xarray import DataArray
1211

1312
import pytensor.scalar as ps
14-
import pytensor.xtensor as px
1513
import pytensor.xtensor.math as pxm
1614
from pytensor import function
1715
from pytensor.scalar import ScalarOp
@@ -316,148 +314,3 @@ def test_dot_errors():
316314
# Doesn't fail until the rewrite
317315
with pytest.raises(ValueError, match="not aligned"):
318316
fn(x_test, y_test)
319-
320-
321-
def test_full_like():
322-
"""Test full_like function, comparing with xarray's full_like."""
323-
324-
# Basic functionality with scalar fill_value
325-
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
326-
x_test = xr_arange_like(x)
327-
328-
y1 = px.full_like(x, 5.0)
329-
fn1 = xr_function([x], y1)
330-
result1 = fn1(x_test)
331-
expected1 = xr.full_like(x_test, 5.0)
332-
xr_assert_allclose(result1, expected1, check_dtype=True)
333-
334-
# Other dtypes
335-
x_3d = xtensor("x_3d", dims=("a", "b", "c"), shape=(2, 3, 4), dtype="float32")
336-
x_3d_test = xr_arange_like(x_3d)
337-
338-
y7 = px.full_like(x_3d, -1.0)
339-
fn7 = xr_function([x_3d], y7)
340-
result7 = fn7(x_3d_test)
341-
expected7 = xr.full_like(x_3d_test, -1.0)
342-
xr_assert_allclose(result7, expected7, check_dtype=True)
343-
344-
# Integer dtype
345-
y3 = px.full_like(x, 5.0, dtype="int32")
346-
fn3 = xr_function([x], y3)
347-
result3 = fn3(x_test)
348-
expected3 = xr.full_like(x_test, 5.0, dtype="int32")
349-
xr_assert_allclose(result3, expected3, check_dtype=True)
350-
351-
# Different fill_value types
352-
y4 = px.full_like(x, np.array(3.14))
353-
fn4 = xr_function([x], y4)
354-
result4 = fn4(x_test)
355-
expected4 = xr.full_like(x_test, 3.14)
356-
xr_assert_allclose(result4, expected4, check_dtype=True)
357-
358-
# Integer input with float fill_value
359-
x_int = xtensor("x_int", dims=("a", "b"), shape=(2, 3), dtype="int32")
360-
x_int_test = DataArray(np.arange(6, dtype="int32").reshape(2, 3), dims=("a", "b"))
361-
362-
y5 = px.full_like(x_int, 2.5)
363-
fn5 = xr_function([x_int], y5)
364-
result5 = fn5(x_int_test)
365-
expected5 = xr.full_like(x_int_test, 2.5)
366-
xr_assert_allclose(result5, expected5, check_dtype=True)
367-
368-
# Symbolic shapes
369-
x_sym = xtensor("x_sym", dims=("a", "b"), shape=(None, 3))
370-
x_sym_test = DataArray(
371-
np.arange(6, dtype=x_sym.type.dtype).reshape(2, 3), dims=("a", "b")
372-
)
373-
374-
y6 = px.full_like(x_sym, 7.0)
375-
fn6 = xr_function([x_sym], y6)
376-
result6 = fn6(x_sym_test)
377-
expected6 = xr.full_like(x_sym_test, 7.0)
378-
xr_assert_allclose(result6, expected6, check_dtype=True)
379-
380-
# Boolean dtype
381-
x_bool = xtensor("x_bool", dims=("a", "b"), shape=(2, 3), dtype="bool")
382-
x_bool_test = DataArray(
383-
np.array([[True, False, True], [False, True, False]]), dims=("a", "b")
384-
)
385-
386-
y8 = px.full_like(x_bool, True)
387-
fn8 = xr_function([x_bool], y8)
388-
result8 = fn8(x_bool_test)
389-
expected8 = xr.full_like(x_bool_test, True)
390-
xr_assert_allclose(result8, expected8, check_dtype=True)
391-
392-
# Complex dtype
393-
x_complex = xtensor("x_complex", dims=("a", "b"), shape=(2, 3), dtype="complex64")
394-
x_complex_test = DataArray(
395-
np.arange(6, dtype="complex64").reshape(2, 3), dims=("a", "b")
396-
)
397-
398-
y9 = px.full_like(x_complex, 1 + 2j)
399-
fn9 = xr_function([x_complex], y9)
400-
result9 = fn9(x_complex_test)
401-
expected9 = xr.full_like(x_complex_test, 1 + 2j)
402-
xr_assert_allclose(result9, expected9, check_dtype=True)
403-
404-
# Symbolic fill value
405-
x_sym_fill = xtensor("x_sym_fill", dims=("a", "b"), shape=(2, 3), dtype="float64")
406-
fill_val = xtensor("fill_val", dims=(), shape=(), dtype="float64")
407-
x_sym_fill_test = xr_arange_like(x_sym_fill)
408-
fill_val_test = DataArray(3.14, dims=())
409-
410-
y10 = px.full_like(x_sym_fill, fill_val)
411-
fn10 = xr_function([x_sym_fill, fill_val], y10)
412-
result10 = fn10(x_sym_fill_test, fill_val_test)
413-
expected10 = xr.full_like(x_sym_fill_test, 3.14)
414-
xr_assert_allclose(result10, expected10, check_dtype=True)
415-
416-
# Test dtype conversion to bool when neither input nor fill_value are bool
417-
x_float = xtensor("x_float", dims=("a", "b"), shape=(2, 3), dtype="float64")
418-
x_float_test = xr_arange_like(x_float)
419-
420-
y11 = px.full_like(x_float, 5.0, dtype="bool")
421-
fn11 = xr_function([x_float], y11)
422-
result11 = fn11(x_float_test)
423-
expected11 = xr.full_like(x_float_test, 5.0, dtype="bool")
424-
xr_assert_allclose(result11, expected11, check_dtype=True)
425-
426-
# Verify the result is actually boolean
427-
assert result11.dtype == "bool"
428-
assert expected11.dtype == "bool"
429-
430-
431-
def test_full_like_errors():
432-
"""Test full_like function errors."""
433-
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
434-
x_test = xr_arange_like(x)
435-
436-
with pytest.raises(ValueError, match="fill_value must be a scalar"):
437-
px.full_like(x, x_test)
438-
439-
440-
def test_ones_like():
441-
"""Test ones_like function, comparing with xarray's ones_like."""
442-
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
443-
x_test = xr_arange_like(x)
444-
445-
y1 = px.ones_like(x)
446-
fn1 = xr_function([x], y1)
447-
result1 = fn1(x_test)
448-
expected1 = xr.ones_like(x_test)
449-
xr_assert_allclose(result1, expected1)
450-
assert result1.dtype == expected1.dtype
451-
452-
453-
def test_zeros_like():
454-
"""Test zeros_like function, comparing with xarray's zeros_like."""
455-
x = xtensor("x", dims=("a", "b"), shape=(2, 3), dtype="float64")
456-
x_test = xr_arange_like(x)
457-
458-
y1 = px.zeros_like(x)
459-
fn1 = xr_function([x], y1)
460-
result1 = fn1(x_test)
461-
expected1 = xr.zeros_like(x_test)
462-
xr_assert_allclose(result1, expected1)
463-
assert result1.dtype == expected1.dtype

0 commit comments

Comments
 (0)