@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
583
583
return [out , argout ]
584
584
585
585
586
- class NonZeroCAReduce (CAReduce ):
586
+ class FixedOpCAReduce (CAReduce ):
587
+ def __str__ (self ):
588
+ return f"{ type (self ).__name__ } {{{ self ._axis_str ()} }}"
589
+
590
+
591
+ class NonZeroDimsCAReduce (FixedOpCAReduce ):
587
592
def _c_all (self , node , name , inames , onames , sub ):
588
593
decl , checks , alloc , loop , end = super ()._c_all (node , name , inames , onames , sub )
589
594
@@ -614,7 +619,7 @@ def _c_all(self, node, name, inames, onames, sub):
614
619
return decl , checks , alloc , loop , end
615
620
616
621
617
- class Max (NonZeroCAReduce ):
622
+ class Max (NonZeroDimsCAReduce ):
618
623
nfunc_spec = ("max" , 1 , 1 )
619
624
620
625
def __init__ (self , axis ):
@@ -625,7 +630,7 @@ def clone(self, **kwargs):
625
630
return type (self )(axis = axis )
626
631
627
632
628
- class Min (NonZeroCAReduce ):
633
+ class Min (NonZeroDimsCAReduce ):
629
634
nfunc_spec = ("min" , 1 , 1 )
630
635
631
636
def __init__ (self , axis ):
@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
1496
1501
"""Return complex-valued tensor from polar coordinate specification."""
1497
1502
1498
1503
1499
- class Mean (CAReduce ):
1504
+ class Mean (FixedOpCAReduce ):
1500
1505
__props__ = ("axis" ,)
1501
1506
nfunc_spec = ("mean" , 1 , 1 )
1502
1507
@@ -2356,7 +2361,7 @@ def outer(x, y):
2356
2361
return dot (x .dimshuffle (0 , "x" ), y .dimshuffle ("x" , 0 ))
2357
2362
2358
2363
2359
- class All (CAReduce ):
2364
+ class All (FixedOpCAReduce ):
2360
2365
"""Applies `logical and` to all the values of a tensor along the
2361
2366
specified axis(es).
2362
2367
@@ -2370,12 +2375,6 @@ def __init__(self, axis=None):
2370
2375
def _output_dtype (self , idtype ):
2371
2376
return "bool"
2372
2377
2373
- def __str__ (self ):
2374
- if self .axis is None :
2375
- return "All"
2376
- else :
2377
- return "All{%s}" % ", " .join (map (str , self .axis ))
2378
-
2379
2378
def make_node (self , input ):
2380
2379
input = as_tensor_variable (input )
2381
2380
if input .dtype != "bool" :
@@ -2392,7 +2391,7 @@ def clone(self, **kwargs):
2392
2391
return type (self )(axis = axis )
2393
2392
2394
2393
2395
- class Any (CAReduce ):
2394
+ class Any (FixedOpCAReduce ):
2396
2395
"""Applies `bitwise or` to all the values of a tensor along the
2397
2396
specified axis(es).
2398
2397
@@ -2406,12 +2405,6 @@ def __init__(self, axis=None):
2406
2405
def _output_dtype (self , idtype ):
2407
2406
return "bool"
2408
2407
2409
- def __str__ (self ):
2410
- if self .axis is None :
2411
- return "Any"
2412
- else :
2413
- return "Any{%s}" % ", " .join (map (str , self .axis ))
2414
-
2415
2408
def make_node (self , input ):
2416
2409
input = as_tensor_variable (input )
2417
2410
if input .dtype != "bool" :
@@ -2428,7 +2421,7 @@ def clone(self, **kwargs):
2428
2421
return type (self )(axis = axis )
2429
2422
2430
2423
2431
- class Sum (CAReduce ):
2424
+ class Sum (FixedOpCAReduce ):
2432
2425
"""
2433
2426
Sums all the values of a tensor along the specified axis(es).
2434
2427
@@ -2449,14 +2442,6 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None):
2449
2442
upcast_discrete_output = True ,
2450
2443
)
2451
2444
2452
- def __str__ (self ):
2453
- name = self .__class__ .__name__
2454
- axis = ""
2455
- if self .axis is not None :
2456
- axis = ", " .join (str (x ) for x in self .axis )
2457
- axis = f"axis=[{ axis } ], "
2458
- return f"{ name } {{{ axis } acc_dtype={ self .acc_dtype } }}"
2459
-
2460
2445
def L_op (self , inp , out , grads ):
2461
2446
(x ,) = inp
2462
2447
@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
2526
2511
pprint .assign (Sum , printing .FunctionPrinter (["sum" ], ["axis" ]))
2527
2512
2528
2513
2529
- class Prod (CAReduce ):
2514
+ class Prod (FixedOpCAReduce ):
2530
2515
"""
2531
2516
Multiplies all the values of a tensor along the specified axis(es).
2532
2517
@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
2537
2522
"""
2538
2523
2539
2524
__props__ = ("scalar_op" , "axis" , "dtype" , "acc_dtype" , "no_zeros_in_input" )
2540
-
2541
2525
nfunc_spec = ("prod" , 1 , 1 )
2542
2526
2543
2527
def __init__ (self , axis = None , dtype = None , acc_dtype = None , no_zeros_in_input = False ):
@@ -2683,6 +2667,14 @@ def clone(self, **kwargs):
2683
2667
no_zeros_in_input = no_zeros_in_input ,
2684
2668
)
2685
2669
2670
+ def __str__ (self ):
2671
+ if self .no_zeros_in_input :
2672
+ return f"{ super ().__str__ ()[:- 1 ]} , no_zeros_in_input}})"
2673
+ return super ().__str__ ()
2674
+
2675
+ def __repr__ (self ):
2676
+ return f"{ super ().__repr__ ()[:- 1 ]} , no_zeros_in_input={ self .no_zeros_in_input } )"
2677
+
2686
2678
2687
2679
def prod (
2688
2680
input ,
@@ -2751,7 +2743,7 @@ def c_code_cache_version(self):
2751
2743
mul_without_zeros = MulWithoutZeros (aes .upcast_out , name = "mul_without_zeros" )
2752
2744
2753
2745
2754
- class ProdWithoutZeros (CAReduce ):
2746
+ class ProdWithoutZeros (FixedOpCAReduce ):
2755
2747
def __init__ (self , axis = None , dtype = None , acc_dtype = None ):
2756
2748
super ().__init__ (
2757
2749
mul_without_zeros ,
0 commit comments