1313 _jit = lambda f : f
1414
1515
16- def _debug (mode : bool = __debug__ ) -> bool :
17- r"""Returns whether debugging is enabled or not.
16+ __piqa_debug__ = __debug__
17+
18+ def set_debug (mode : bool = False ) -> bool :
19+ r"""Sets and returns whether debugging is enabled or not.
20+ If `__debug__` is `False`, this function has not effect.
21+
22+ Example:
23+ >>> set_debug(False)
24+ False
1825 """
1926
20- return mode
27+ global __piqa_debug__
28+
29+ __piqa_debug__ = __debug__ and mode
30+
31+ return __piqa_debug__
2132
2233
23- def _assert_type (
34+ def assert_type (
2435 tensors : List [torch .Tensor ],
2536 device : torch .device ,
2637 dim_range : Tuple [int , int ] = (0 , - 1 ),
@@ -33,60 +44,60 @@ def _assert_type(
3344 Example:
3445 >>> x = torch.rand(5, 3, 256, 256)
3546 >>> y = torch.rand(5, 3, 256, 256)
36- >>> _assert_type ([x, y], device=x.device, dim_range=(4, 4), n_channels=3)
47+ >>> assert_type ([x, y], device=x.device, dim_range=(4, 4), n_channels=3)
3748 """
3849
39- if not _debug () :
50+ if not __piqa_debug__ :
4051 return
4152
4253 ref = tensors [0 ]
4354
4455 for t in tensors :
4556 assert t .device == device , (
46- f'Expected tensors to be on { device } , got { t .device } '
57+ f'Tensors expected to be on { device } , got { t .device } '
4758 )
4859
4960 assert t .shape == ref .shape , (
50- 'Expected tensors to be of the same shape, got'
61+ 'Tensors expected to be of the same shape, got'
5162 f' { ref .shape } and { t .shape } '
5263 )
5364
5465 if dim_range [0 ] == dim_range [1 ]:
5566 assert t .dim () == dim_range [0 ], (
56- 'Expected number of dimensions to be'
67+ 'Number of dimensions expected to be'
5768 f' { dim_range [0 ]} , got { t .dim ()} '
5869 )
5970 elif dim_range [0 ] < dim_range [1 ]:
6071 assert dim_range [0 ] <= t .dim () <= dim_range [1 ], (
61- 'Expected number of dimensions to be between'
72+ 'Number of dimensions expected to be between'
6273 f' { dim_range [0 ]} and { dim_range [1 ]} , got { t .dim ()} '
6374 )
6475 elif dim_range [0 ] > 0 :
6576 assert dim_range [0 ] <= t .dim (), (
66- 'Expected number of dimensions to be greater or equal to'
77+ 'Number of dimensions expected to be greater or equal to'
6778 f' { dim_range [0 ]} , got { t .dim ()} '
6879 )
6980
7081 if n_channels > 0 :
7182 assert t .size (1 ) == n_channels , (
72- 'Expected number of channels to be'
83+ 'Number of channels expected to be'
7384 f' { n_channels } , got { t .size (1 )} '
7485 )
7586
7687 if value_range [0 ] < value_range [1 ]:
7788 assert value_range [0 ] <= t .min (), (
78- 'Expected values to be greater or equal to'
89+ 'Values expected to be greater or equal to'
7990 f' { value_range [0 ]} , got { t .min ()} '
8091 )
8192
8293 assert t .max () <= value_range [1 ], (
83- 'Expected values to be lower or equal to'
94+ 'Values expected to be lower or equal to'
8495 f' { value_range [1 ]} , got { t .max ()} '
8596 )
8697
8798
8899@_jit
89- def _reduce (x : torch .Tensor , reduction : str = 'mean' ) -> torch .Tensor :
100+ def reduce_tensor (x : torch .Tensor , reduction : str = 'mean' ) -> torch .Tensor :
90101 r"""Returns the reduction of \(x\).
91102
92103 Args:
@@ -96,7 +107,7 @@ def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
96107
97108 Example:
98109 >>> x = torch.arange(5)
99- >>> _reduce (x, reduction='sum')
110+ >>> reduce_tensor (x, reduction='sum')
100111 tensor(10)
101112 """
102113
0 commit comments